summaryrefslogtreecommitdiff
path: root/src/passes/MultiMemoryLowering.cpp
diff options
context:
space:
mode:
authorAshley Nelson <nashley@google.com>2022-11-01 12:16:16 -0700
committerGitHub <noreply@github.com>2022-11-01 12:16:16 -0700
commit5892f493706558a95d57b3731014b400fd832d70 (patch)
tree43772fdb2ee7f59f29c0043ab29414946b9cfa43 /src/passes/MultiMemoryLowering.cpp
parent288a12645d060d8f2ec97b13b5795cd53a8a7811 (diff)
downloadbinaryen-5892f493706558a95d57b3731014b400fd832d70.tar.gz
binaryen-5892f493706558a95d57b3731014b400fd832d70.tar.bz2
binaryen-5892f493706558a95d57b3731014b400fd832d70.zip
Multi-Memories Lowering Pass (#5107)
Adds a multi-memories lowering pass that will create a single combined memory from the memories added to the module. This pass assumes that each memory is configured the same (type, shared). This pass also: - replaces existing memory.size instructions with a custom function that returns the size of each memory as if they existed independently - replaces existing memory.grow instructions with a custom function, using global offsets to track the page size of each memory so data doesn't overlap in the single combined memory - adjusts the offsets of active data segments - adjusts the offsets of Loads/Stores
Diffstat (limited to 'src/passes/MultiMemoryLowering.cpp')
-rw-r--r--src/passes/MultiMemoryLowering.cpp421
1 files changed, 421 insertions, 0 deletions
diff --git a/src/passes/MultiMemoryLowering.cpp b/src/passes/MultiMemoryLowering.cpp
new file mode 100644
index 000000000..9e9637b36
--- /dev/null
+++ b/src/passes/MultiMemoryLowering.cpp
@@ -0,0 +1,421 @@
+/*
+ * Copyright 2022 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Condensing a module with multiple memories into a module with a single memory
+// for browsers that don’t support multiple memories.
+//
+// This pass also disables multi-memories so that the target features section in
+// the emitted module does not report the use of MultiMemories. Disabling the
+// multi-memories feature also prevents later passes from adding additional
+// memories.
+//
+// Also worth noting that we are diverging from the spec with regard to
+// handling load and store instructions. We are not trapping if the offset +
+// access size is larger than the length of the memory's data. Warning:
+// out-of-bounds loads and stores can read junk out of or corrupt other
+// memories instead of trapping.
+
+#include "ir/module-utils.h"
+#include "ir/names.h"
+#include "wasm-builder.h"
+#include <pass.h>
+#include <wasm.h>
+
+namespace wasm {
+
+struct MultiMemoryLowering : public Pass {
+ Module* wasm = nullptr; // The module being lowered; assigned in run()
+ // The name of the combined memory that replaces all memories in the module
+ Name combinedMemory;
+ // The index type of the first memory, which all memories are asserted to share
+ Type pointerType;
+ // Selects the 32- or 64-bit form when creating instructions (memory.grow,
+ // memory.size) that operate on the combined memory
+ Builder::MemoryInfo memoryInfo;
+ // Whether the combined memory is shared
+ bool isShared;
+ // The initial size of the combined memory, in pages
+ Address totalInitialPages;
+ // The maximum size of the combined memory, in pages
+ Address totalMaxPages;
+ // Names of the mutable globals holding each memory's byte offset within the
+ // combined memory. The first memory has no global (its offset is always 0),
+ // so this vector has one fewer entry than there are memories. Use the helper
+ // getOffsetGlobal(Index) instead of indexing this vector directly
+ std::vector<Name> offsetGlobalNames;
+ // Maps the name of each original memory to its index in the
+ // module->memories vector
+ std::unordered_map<Name, Index> memoryIdxMap;
+ // Names of the generated memory size functions, one per original memory,
+ // in memory index order
+ std::vector<Name> memorySizeNames;
+ // Names of the generated memory grow functions, one per original memory,
+ // in memory index order
+ std::vector<Name> memoryGrowNames;
+
+  void run(Module* module) override {
+    // The feature is switched off unconditionally, including when the module
+    // has nothing to merge, so the emitted module never advertises it.
+    module->features.disable(FeatureSet::MultiMemories);
+
+    // Zero or one memory: there is nothing to combine.
+    if (module->memories.size() < 2) {
+      return;
+    }
+
+    wasm = module;
+
+    // Each step relies on state produced by the ones before it (names, index
+    // map, offset globals, helper functions), so the order matters.
+    prepCombinedMemory();
+    addOffsetGlobals();
+    adjustActiveDataSegmentOffsets();
+    createMemorySizeFunctions();
+    createMemoryGrowFunctions();
+    removeExistingMemories();
+    addCombinedMemory();
+
+    // Rewrites every memory instruction in user code so that it targets the
+    // combined memory.
+    struct Replacer : public WalkerPass<PostWalker<Replacer>> {
+      MultiMemoryLowering& parent;
+      Builder builder;
+      Replacer(MultiMemoryLowering& parent, Module& wasm)
+        : parent(parent), builder(wasm) {}
+
+      // The size/grow helpers generated by the parent pass already address
+      // the combined memory, so they are left untouched.
+      void walkFunction(Function* func) {
+        auto isGeneratedHelper = [&](const std::vector<Name>& helperNames) {
+          for (const Name& helperName : helperNames) {
+            if (helperName == func->name) {
+              return true;
+            }
+          }
+          return false;
+        };
+        if (isGeneratedHelper(parent.memorySizeNames) ||
+            isGeneratedHelper(parent.memoryGrowNames)) {
+          return;
+        }
+        super::walkFunction(func);
+      }
+
+      // memory.grow becomes a call to the per-memory grow helper.
+      void visitMemoryGrow(MemoryGrow* curr) {
+        Name growFunc =
+          parent.memoryGrowNames[parent.memoryIdxMap.at(curr->memory)];
+        replaceCurrent(builder.makeCall(growFunc, {curr->delta}, curr->type));
+      }
+
+      // memory.size becomes a call to the per-memory size helper.
+      void visitMemorySize(MemorySize* curr) {
+        Name sizeFunc =
+          parent.memorySizeNames[parent.memoryIdxMap.at(curr->memory)];
+        replaceCurrent(builder.makeCall(sizeFunc, {}, curr->type));
+      }
+
+      // Points a memory access at the combined memory and, for every memory
+      // but the first, adds that memory's offset global to the address.
+      void retarget(Name& memory, Expression*& ptr) {
+        auto memIdx = parent.memoryIdxMap.at(memory);
+        auto offsetGlobal = parent.getOffsetGlobal(memIdx);
+        memory = parent.combinedMemory;
+        if (offsetGlobal) {
+          ptr = builder.makeBinary(
+            Abstract::getBinary(parent.pointerType, Abstract::Add),
+            builder.makeGlobalGet(offsetGlobal, parent.pointerType),
+            ptr);
+        }
+      }
+
+      // TODO: Add an option to add bounds checks.
+      void visitLoad(Load* curr) { retarget(curr->memory, curr->ptr); }
+
+      // This diverges from the spec: no trap is raised when offset + access
+      // size runs past the end of the original memory's data. Warning:
+      // out-of-bounds loads and stores can read junk out of or corrupt other
+      // memories instead of trapping.
+      void visitStore(Store* curr) { retarget(curr->memory, curr->ptr); }
+    };
+
+    Replacer replacer(*this, *wasm);
+    replacer.run(getPassRunner(), wasm);
+  }
+
+ // Returns the global name for the given idx. There is no global for the first
+ // idx, so an empty name is returned
+  Name getOffsetGlobal(Index idx) {
+    // The first memory has no offset global, so it maps to an empty name.
+    // Every later memory's global lives one slot earlier in
+    // offsetGlobalNames, hence the idx - 1.
+    return idx == 0 ? Name() : offsetGlobalNames[idx - 1];
+  }
+
+ // Whether the idx represents the last memory. Since there is no offset global
+ // for the first memory, the last memory is represented by the size of
+ // offsetGlobalNames
+  bool isLastMemory(Index idx) {
+    // offsetGlobalNames has an entry for every memory except the first, so
+    // its size is exactly the index of the last memory.
+    return offsetGlobalNames.size() == idx;
+  }
+
+  void prepCombinedMemory() {
+    // The first memory defines the configuration (index type, sharedness)
+    // that every other memory is expected to match.
+    pointerType = wasm->memories[0]->indexType;
+    memoryInfo = pointerType == Type::i32 ? Builder::MemoryInfo::Memory32
+                                          : Builder::MemoryInfo::Memory64;
+    isShared = wasm->memories[0]->shared;
+
+    // Set when at least one memory declares no maximum. Such a memory may
+    // grow without bound, so the combined memory must not be given a finite
+    // maximum derived only from the memories that do declare one.
+    bool anyUnboundedMemory = false;
+    for (auto& memory : wasm->memories) {
+      // We are assuming that each memory is configured the same as the first
+      // and assert if any of the memories does not match this configuration
+      assert(memory->shared == isShared);
+      assert(memory->indexType == pointerType);
+
+      // The combined memory's initial and max sizes are the totals of the
+      // initial and max sizes of the memories in the module.
+      totalInitialPages = totalInitialPages + memory->initial;
+      if (memory->hasMax()) {
+        totalMaxPages = totalMaxPages + memory->max;
+      } else {
+        anyUnboundedMemory = true;
+      }
+    }
+
+    // The combined max is unlimited if any memory is unbounded, or if the
+    // summed max exceeds the number of pages addressable by pointerType.
+    Address maxSize =
+      pointerType == Type::i32 ? Memory::kMaxSize32 : Memory::kMaxSize64;
+    if (anyUnboundedMemory || totalMaxPages > maxSize) {
+      totalMaxPages = Memory::kUnlimitedSize;
+    }
+    // Keep the initial size within the max so the memory stays valid.
+    if (totalInitialPages > totalMaxPages) {
+      totalInitialPages = totalMaxPages;
+    }
+
+    // Creating the combined memory name so we can reference the combined memory
+    // in subsequent instructions before it is added to the module
+    combinedMemory = Names::getValidMemoryName(*wasm, "combined_memory");
+  }
+
+  void addOffsetGlobals() {
+    // Adds a mutable global of pointerType initialized to the given byte
+    // offset. It is mutable because the generated grow functions update it.
+    auto addGlobal = [&](Name name, size_t offset) {
+      auto global = Builder::makeGlobal(
+        name,
+        pointerType,
+        Builder(*wasm).makeConst(Literal::makeFromInt64(offset, pointerType)),
+        Builder::Mutable);
+      wasm->addGlobal(std::move(global));
+    };
+
+    // Running total, in pages, of the initial sizes of the memories seen so
+    // far; this is where the next memory starts in the combined memory.
+    size_t offsetRunningTotal = 0;
+    for (Index i = 0; i < wasm->memories.size(); i++) {
+      auto& memory = wasm->memories[i];
+      memoryIdxMap[memory->name] = i;
+      // We don't need a page offset global for the first memory as it's always
+      // 0
+      if (i != 0) {
+        Name name = Names::getValidGlobalName(
+          *wasm, memory->name.toString() + "_byte_offset");
+        // Copy rather than move: name is still needed by addGlobal below.
+        offsetGlobalNames.push_back(name);
+        addGlobal(name, offsetRunningTotal * Memory::kPageSize);
+      }
+      offsetRunningTotal += memory->initial;
+    }
+  }
+
+  void adjustActiveDataSegmentOffsets() {
+    Builder builder(*wasm);
+    // Retargets one active segment at the combined memory and, unless it
+    // belonged to the first memory, rebases its offset on that memory's
+    // offset global.
+    auto rebaseSegment = [&](DataSegment* segment) {
+      assert(segment->offset->is<Const>() &&
+             "TODO: handle non-const segment offsets");
+      auto memIdx = memoryIdxMap.at(segment->memory);
+      segment->memory = combinedMemory;
+      // Segments of the first memory keep their offset unchanged.
+      if (memIdx == 0) {
+        return;
+      }
+      auto offsetGlobal = getOffsetGlobal(memIdx);
+      // The rebased offset is a global.get + const addition, which the pass
+      // only emits when extended-const is enabled.
+      assert(wasm->features.hasExtendedConst());
+      segment->offset =
+        builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Add),
+                           builder.makeGlobalGet(offsetGlobal, pointerType),
+                           segment->offset);
+    };
+    ModuleUtils::iterActiveDataSegments(*wasm, rebaseSegment);
+  }
+
+  void createMemorySizeFunctions() {
+    // One size helper per memory, recorded in memory index order so that
+    // memorySizeNames can be indexed by memory index.
+    Index memIdx = 0;
+    for (auto& memory : wasm->memories) {
+      auto sizeFunc = memorySize(memIdx, memory->name);
+      memorySizeNames.push_back(sizeFunc->name);
+      wasm->addFunction(std::move(sizeFunc));
+      memIdx++;
+    }
+  }
+
+  void createMemoryGrowFunctions() {
+    // One grow helper per memory, recorded in memory index order so that
+    // memoryGrowNames can be indexed by memory index.
+    Index memIdx = 0;
+    for (auto& memory : wasm->memories) {
+      auto growFunc = memoryGrow(memIdx, memory->name);
+      memoryGrowNames.push_back(growFunc->name);
+      wasm->addFunction(std::move(growFunc));
+      memIdx++;
+    }
+  }
+
+ // This function replaces memory.grow instruction calls in the wasm module.
+ // Because the multiple discrete memories are lowered into a single memory,
+ // we need to adjust offsets as a particular memory receives an
+ // instruction to grow.
+  std::unique_ptr<Function> memoryGrow(Index memIdx, Name memoryName) {
+    // Builds the helper that stands in for memory.grow on memory memIdx. The
+    // generated function takes the page delta and returns the memory's size
+    // (in pages) from before the grow. Its body, in order:
+    //   1. return_size = <memory>_size()
+    //   2. memory_size = memory.size of the combined memory (not for the last
+    //      memory)
+    //   3. grow the combined memory by page_delta (result dropped)
+    //   4. move the data of all later memories up by page_delta pages (not
+    //      for the last memory)
+    //   5. add page_delta pages (in bytes) to the offset global of every
+    //      later memory
+    //   6. yield return_size
+    Builder builder(*wasm);
+    Name name = memoryName.toString() + "_grow";
+    Name functionName = Names::getValidFunctionName(*wasm, name);
+    auto function = Builder::makeFunction(
+      functionName, Signature(pointerType, pointerType), {});
+    function->setLocalName(0, "page_delta");
+    // The page size as a constant of pointerType. It is combined with
+    // pointerType operands in binaries picked by pointerType, so its type
+    // must follow pointerType rather than the C++ type of kPageSize.
+    auto pageSizeConst = [&]() {
+      return builder.makeConst(
+        Literal::makeFromInt64(Memory::kPageSize, pointerType));
+    };
+    // page_delta converted from pages to bytes.
+    auto getOffsetDelta = [&]() {
+      return builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Mul),
+                                builder.makeLocalGet(0, pointerType),
+                                pageSizeConst());
+    };
+    // Current byte offset stored in the given offset global.
+    auto getMoveSource = [&](Name global) {
+      return builder.makeGlobalGet(global, pointerType);
+    };
+    Expression* functionBody;
+    // Only assigned (and only read) when this is not the last memory.
+    Index sizeLocal = -1;
+
+    // Capture this memory's size before anything changes; it is the
+    // function's result.
+    Index returnLocal =
+      Builder::addVar(function.get(), "return_size", pointerType);
+    functionBody = builder.blockify(builder.makeLocalSet(
+      returnLocal, builder.makeCall(memorySizeNames[memIdx], {}, pointerType)));
+
+    // Remember the combined memory's size before growing; it bounds the
+    // region that has to be moved.
+    if (!isLastMemory(memIdx)) {
+      sizeLocal = Builder::addVar(function.get(), "memory_size", pointerType);
+      functionBody = builder.blockify(
+        functionBody,
+        builder.makeLocalSet(
+          sizeLocal, builder.makeMemorySize(combinedMemory, memoryInfo)));
+    }
+
+    // TODO: Check the result of makeMemoryGrow for errors and return the error
+    // instead
+    functionBody = builder.blockify(
+      functionBody,
+      builder.makeDrop(builder.makeMemoryGrow(
+        builder.makeLocalGet(0, pointerType), combinedMemory, memoryInfo)));
+
+    // If we are not growing the last memory, then we need to copy data,
+    // shifting it over to accommodate the increase from page_delta
+    if (!isLastMemory(memIdx)) {
+      // The next memory's offset is where the data to be moved begins.
+      auto offsetGlobalName = getOffsetGlobal(memIdx + 1);
+      functionBody = builder.blockify(
+        functionBody,
+        builder.makeMemoryCopy(
+          // destination: the old start shifted by the growth in bytes
+          builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Add),
+                             getMoveSource(offsetGlobalName),
+                             getOffsetDelta()),
+          // source: the old start of the next memory
+          getMoveSource(offsetGlobalName),
+          // size: everything from the source to the pre-grow end of the
+          // combined memory
+          builder.makeBinary(
+            Abstract::getBinary(pointerType, Abstract::Sub),
+            builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Mul),
+                               builder.makeLocalGet(sizeLocal, pointerType),
+                               pageSizeConst()),
+            getMoveSource(offsetGlobalName)),
+          combinedMemory,
+          combinedMemory));
+    }
+
+    // Adjust the offsets of the globals impacted by the memory.grow call.
+    // offsetGlobalNames[i] belongs to memory i + 1, so starting at memIdx
+    // covers exactly the memories after this one.
+    for (Index i = memIdx; i < offsetGlobalNames.size(); i++) {
+      auto& offsetGlobalName = offsetGlobalNames[i];
+      functionBody = builder.blockify(
+        functionBody,
+        builder.makeGlobalSet(
+          offsetGlobalName,
+          builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Add),
+                             getMoveSource(offsetGlobalName),
+                             getOffsetDelta())));
+    }
+
+    functionBody = builder.blockify(
+      functionBody, builder.makeLocalGet(returnLocal, pointerType));
+
+    function->body = functionBody;
+    return function;
+  }
+
+ // This function replaces memory.size instructions with a function that can
+ // return the size of each memory as if each was discrete and separate.
+  std::unique_ptr<Function> memorySize(Index memIdx, Name memoryName) {
+    // Builds the helper that stands in for memory.size on memory memIdx. The
+    // generated function takes no parameters and returns that memory's size
+    // in pages, derived from the offset globals and, for the last memory,
+    // from the size of the combined memory.
+    Builder builder(*wasm);
+    Name name = memoryName.toString() + "_size";
+    Name functionName = Names::getValidFunctionName(*wasm, name);
+    auto function = Builder::makeFunction(
+      functionName, Signature(Type::none, pointerType), {});
+    Expression* functionBody;
+    // The page size as a constant of pointerType. It is the divisor of a
+    // division picked by pointerType, so its type must follow pointerType
+    // rather than the C++ type of kPageSize.
+    auto pageSizeConst = [&]() {
+      return builder.makeConst(
+        Literal::makeFromInt64(Memory::kPageSize, pointerType));
+    };
+    // The byte offset held in the given global, converted to pages.
+    auto getOffsetInPageUnits = [&](Name global) {
+      return builder.makeBinary(
+        Abstract::getBinary(pointerType, Abstract::DivU),
+        builder.makeGlobalGet(global, pointerType),
+        pageSizeConst());
+    };
+
+    // offsetGlobalNames does not keep track of a global for the offset of
+    // wasm->memories[0] because it's always 0. As a result, the below
+    // calculations that involve offsetGlobalNames are intrinsically "offset".
+    // Thus, offsetGlobalNames[0] is the offset for wasm->memories[1] and
+    // the size of wasm->memories[0].
+    if (memIdx == 0) {
+      // First memory: its size is where the second memory begins.
+      auto offsetGlobalName = getOffsetGlobal(1);
+      functionBody = builder.blockify(
+        builder.makeReturn(getOffsetInPageUnits(offsetGlobalName)));
+    } else if (isLastMemory(memIdx)) {
+      // Last memory: it extends from its own offset to the end of the
+      // combined memory.
+      auto offsetGlobalName = getOffsetGlobal(memIdx);
+      functionBody = builder.blockify(builder.makeReturn(
+        builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Sub),
+                           builder.makeMemorySize(combinedMemory, memoryInfo),
+                           getOffsetInPageUnits(offsetGlobalName))));
+    } else {
+      // Middle memory: it extends from its own offset to the next memory's.
+      auto offsetGlobalName = getOffsetGlobal(memIdx);
+      auto nextOffsetGlobalName = getOffsetGlobal(memIdx + 1);
+      functionBody = builder.blockify(builder.makeReturn(
+        builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Sub),
+                           getOffsetInPageUnits(nextOffsetGlobalName),
+                           getOffsetInPageUnits(offsetGlobalName))));
+    }
+
+    function->body = functionBody;
+    return function;
+  }
+
+  void removeExistingMemories() {
+    // The predicate accepts every memory, so all of the module's memories
+    // are removed.
+    auto matchAll = [](Memory*) { return true; };
+    wasm->removeMemories(matchAll);
+  }
+
+  void addCombinedMemory() {
+    // Create the combined memory from the configuration gathered in
+    // prepCombinedMemory() and add it to the module.
+    auto combined = Builder::makeMemory(combinedMemory);
+    combined->indexType = pointerType;
+    combined->shared = isShared;
+    combined->initial = totalInitialPages;
+    combined->max = totalMaxPages;
+    wasm->addMemory(std::move(combined));
+  }
+};
+
+Pass* createMultiMemoryLoweringPass() {
+  // Returns a newly allocated instance of the multi-memory lowering pass.
+  Pass* pass = new MultiMemoryLowering();
+  return pass;
+}
+
+} // namespace wasm