summaryrefslogtreecommitdiff
path: root/src/passes/MemoryCopyFillLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/MemoryCopyFillLowering.cpp')
-rw-r--r--src/passes/MemoryCopyFillLowering.cpp260
1 files changed, 260 insertions, 0 deletions
diff --git a/src/passes/MemoryCopyFillLowering.cpp b/src/passes/MemoryCopyFillLowering.cpp
new file mode 100644
index 000000000..5855e5450
--- /dev/null
+++ b/src/passes/MemoryCopyFillLowering.cpp
@@ -0,0 +1,260 @@
+/*
+ * Copyright 2024 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/names.h"
+#include "pass.h"
+#include "wasm-builder.h"
+#include "wasm.h"
+
+// Replace memory.copy and memory.fill with a call to a function that
+// implements the same semantics. This is intended to be used with LLVM output,
+// so anything considered undefined behavior in LLVM is ignored. (In
+// particular, pointer overflow is UB and not handled here).
+
+namespace wasm {
+struct MemoryCopyFillLowering
+ : public WalkerPass<PostWalker<MemoryCopyFillLowering>> {
+ bool needsMemoryCopy = false;
+ bool needsMemoryFill = false;
+ Name memCopyFuncName;
+ Name memFillFuncName;
+
+ void visitMemoryCopy(MemoryCopy* curr) {
+ assert(curr->destMemory ==
+ curr->sourceMemory); // multi-memory not supported.
+ Builder builder(*getModule());
+ replaceCurrent(builder.makeCall(
+ "__memory_copy", {curr->dest, curr->source, curr->size}, Type::none));
+ needsMemoryCopy = true;
+ }
+
+ void visitMemoryFill(MemoryFill* curr) {
+ Builder builder(*getModule());
+ replaceCurrent(builder.makeCall(
+ "__memory_fill", {curr->dest, curr->value, curr->size}, Type::none));
+ needsMemoryFill = true;
+ }
+
+ void run(Module* module) override {
+ if (!module->features.hasBulkMemory()) {
+ return;
+ }
+ if (module->features.hasMemory64() || module->features.hasMultiMemory()) {
+ Fatal()
+ << "Memory64 and multi-memory not supported by memory.copy lowering";
+ }
+
+ // Check for the presence of any passive data or table segments.
+ for (auto& segment : module->dataSegments) {
+ if (segment->isPassive) {
+ Fatal() << "memory.copy lowering should only be run on modules with "
+ "no passive segments";
+ }
+ }
+ for (auto& segment : module->elementSegments) {
+ if (!segment->table.is()) {
+ Fatal() << "memory.copy lowering should only be run on modules with"
+ " no passive segments";
+ }
+ }
+
+ // In order to introduce a call to a function, it must first exist, so
+ // create an empty stub.
+ Builder b(*module);
+
+ memCopyFuncName = Names::getValidFunctionName(*module, "__memory_copy");
+ memFillFuncName = Names::getValidFunctionName(*module, "__memory_fill");
+ auto memCopyFunc = b.makeFunction(
+ memCopyFuncName,
+ {{"dst", Type::i32}, {"src", Type::i32}, {"size", Type::i32}},
+ Signature({Type::i32, Type::i32, Type::i32}, {Type::none}),
+ {{"start", Type::i32},
+ {"end", Type::i32},
+ {"step", Type::i32},
+ {"i", Type::i32}});
+ memCopyFunc->body = b.makeBlock();
+ module->addFunction(memCopyFunc.release());
+ auto memFillFunc = b.makeFunction(
+ memFillFuncName,
+ {{"dst", Type::i32}, {"val", Type::i32}, {"size", Type::i32}},
+ Signature({Type::i32, Type::i32, Type::i32}, {Type::none}),
+ {});
+ memFillFunc->body = b.makeBlock();
+ module->addFunction(memFillFunc.release());
+
+ Super::run(module);
+
+ if (needsMemoryCopy) {
+ createMemoryCopyFunc(module);
+ } else {
+ module->removeFunction(memCopyFuncName);
+ }
+
+ if (needsMemoryFill) {
+ createMemoryFillFunc(module);
+ } else {
+ module->removeFunction(memFillFuncName);
+ }
+ module->features.disable(FeatureSet::BulkMemory);
+ }
+
+ void createMemoryCopyFunc(Module* module) {
+ Builder b(*module);
+ Index dst = 0, src = 1, size = 2, start = 3, end = 4, step = 5, i = 6;
+ Name memory = module->memories.front()->name;
+ Block* body = b.makeBlock();
+ // end = memory size in bytes
+ body->list.push_back(
+ b.makeLocalSet(end,
+ b.makeBinary(BinaryOp::MulInt32,
+ b.makeMemorySize(memory),
+ b.makeConst(Memory::kPageSize))));
+ // if dst + size > memsize or src + size > memsize, then trap.
+ body->list.push_back(b.makeIf(
+ b.makeBinary(BinaryOp::OrInt32,
+ b.makeBinary(BinaryOp::GtUInt32,
+ b.makeBinary(BinaryOp::AddInt32,
+ b.makeLocalGet(dst, Type::i32),
+ b.makeLocalGet(size, Type::i32)),
+ b.makeLocalGet(end, Type::i32)),
+ b.makeBinary(BinaryOp::GtUInt32,
+ b.makeBinary(BinaryOp::AddInt32,
+ b.makeLocalGet(src, Type::i32),
+ b.makeLocalGet(size, Type::i32)),
+ b.makeLocalGet(end, Type::i32))),
+ b.makeUnreachable()));
+ // start and end are the starting and past-the-end indexes
+ // if src < dest: start = size - 1, end = -1, step = -1
+ // else: start = 0, end = size, step = 1
+ body->list.push_back(
+ b.makeIf(b.makeBinary(BinaryOp::LtUInt32,
+ b.makeLocalGet(src, Type::i32),
+ b.makeLocalGet(dst, Type::i32)),
+ b.makeBlock({
+ b.makeLocalSet(start,
+ b.makeBinary(BinaryOp::SubInt32,
+ b.makeLocalGet(size, Type::i32),
+ b.makeConst(1))),
+ b.makeLocalSet(end, b.makeConst(-1U)),
+ b.makeLocalSet(step, b.makeConst(-1U)),
+ }),
+ b.makeBlock({
+ b.makeLocalSet(start, b.makeConst(0)),
+ b.makeLocalSet(end, b.makeLocalGet(size, Type::i32)),
+ b.makeLocalSet(step, b.makeConst(1)),
+ })));
+ // i = start
+ body->list.push_back(b.makeLocalSet(i, b.makeLocalGet(start, Type::i32)));
+ body->list.push_back(b.makeBlock(
+ "out",
+ b.makeLoop(
+ "copy",
+ b.makeBlock(
+ {// break if i == end
+ b.makeBreak("out",
+ nullptr,
+ b.makeBinary(BinaryOp::EqInt32,
+ b.makeLocalGet(i, Type::i32),
+ b.makeLocalGet(end, Type::i32))),
+ // dst[i] = src[i]
+ b.makeStore(1,
+ 0,
+ 1,
+ b.makeBinary(BinaryOp::AddInt32,
+ b.makeLocalGet(dst, Type::i32),
+ b.makeLocalGet(i, Type::i32)),
+ b.makeLoad(1,
+ false,
+ 0,
+ 1,
+ b.makeBinary(BinaryOp::AddInt32,
+ b.makeLocalGet(src, Type::i32),
+ b.makeLocalGet(i, Type::i32)),
+ Type::i32,
+ memory),
+ Type::i32,
+ memory),
+ // i += step
+ b.makeLocalSet(i,
+ b.makeBinary(BinaryOp::AddInt32,
+ b.makeLocalGet(i, Type::i32),
+ b.makeLocalGet(step, Type::i32))),
+ // loop
+ b.makeBreak("copy", nullptr)}))));
+ module->getFunction(memCopyFuncName)->body = body;
+ }
+
+ void createMemoryFillFunc(Module* module) {
+ Builder b(*module);
+ Index dst = 0, val = 1, size = 2;
+ Name memory = module->memories.front()->name;
+ Block* body = b.makeBlock();
+
+ // if dst + size > memsize in bytes, then trap.
+ body->list.push_back(
+ b.makeIf(b.makeBinary(BinaryOp::GtUInt32,
+ b.makeBinary(BinaryOp::AddInt32,
+ b.makeLocalGet(dst, Type::i32),
+ b.makeLocalGet(size, Type::i32)),
+ b.makeBinary(BinaryOp::MulInt32,
+ b.makeMemorySize(memory),
+ b.makeConst(Memory::kPageSize))),
+ b.makeUnreachable()));
+
+ body->list.push_back(b.makeBlock(
+ "out",
+ b.makeLoop(
+ "copy",
+ b.makeBlock(
+ {// break if size == 0
+ b.makeBreak(
+ "out",
+ nullptr,
+ b.makeUnary(UnaryOp::EqZInt32, b.makeLocalGet(size, Type::i32))),
+ // size--
+ b.makeLocalSet(size,
+ b.makeBinary(BinaryOp::SubInt32,
+ b.makeLocalGet(size, Type::i32),
+ b.makeConst(1))),
+ // *(dst+size) = val
+ b.makeStore(1,
+ 0,
+ 1,
+ b.makeBinary(BinaryOp::AddInt32,
+ b.makeLocalGet(dst, Type::i32),
+ b.makeLocalGet(size, Type::i32)),
+ b.makeLocalGet(val, Type::i32),
+ Type::i32,
+ memory),
+ b.makeBreak("copy", nullptr)}))));
+ module->getFunction(memFillFuncName)->body = body;
+ }
+
+ void VisitTableCopy(TableCopy* curr) {
+ Fatal() << "table.copy instruction found. Memory copy lowering is not "
+ "designed to work on modules with bulk table operations";
+ }
+ void VisitTableFill(TableCopy* curr) {
+ Fatal() << "table.fill instruction found. Memory copy lowering is not "
+ "designed to work on modules with bulk table operations";
+ }
+};
+
+Pass* createMemoryCopyFillLoweringPass() {
+ return new MemoryCopyFillLowering();
+}
+
+} // namespace wasm