diff options
author | Ashley Nelson <nashley@google.com> | 2022-11-01 12:16:16 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-01 12:16:16 -0700 |
commit | 5892f493706558a95d57b3731014b400fd832d70 (patch) | |
tree | 43772fdb2ee7f59f29c0043ab29414946b9cfa43 /src/passes/MultiMemoryLowering.cpp | |
parent | 288a12645d060d8f2ec97b13b5795cd53a8a7811 (diff) | |
download | binaryen-5892f493706558a95d57b3731014b400fd832d70.tar.gz binaryen-5892f493706558a95d57b3731014b400fd832d70.tar.bz2 binaryen-5892f493706558a95d57b3731014b400fd832d70.zip |
Multi-Memories Lowering Pass (#5107)
Adds a multi-memories lowering pass that will create a single combined memory from the memories added to the module. This pass assumes that each memory is configured the same (type, shared).
This pass also:
- replaces existing memory.size instructions with a custom function that returns the size of each memory as if they existed independently
- replaces existing memory.grow instructions with a custom function, using global offsets to track the page size of each memory so data doesn't overlap in the singled combined memory
- adjusts the offsets of active data segments
- adjusts the offsets of Loads/Stores
Diffstat (limited to 'src/passes/MultiMemoryLowering.cpp')
-rw-r--r-- | src/passes/MultiMemoryLowering.cpp | 421 |
1 files changed, 421 insertions, 0 deletions
diff --git a/src/passes/MultiMemoryLowering.cpp b/src/passes/MultiMemoryLowering.cpp new file mode 100644 index 000000000..9e9637b36 --- /dev/null +++ b/src/passes/MultiMemoryLowering.cpp @@ -0,0 +1,421 @@ +/* + * Copyright 2022 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Condensing a module with multiple memories into a module with a single memory +// for browsers that don’t support multiple memories. +// +// This pass also disables multi-memories so that the target features section in +// the emitted module does not report the use of MultiMemories. Disabling the +// multi-memories feature also prevents later passes from adding additional +// memories. +// +// Also worth noting that we are diverging from the spec with regards to +// handling load and store instructions. We are not trapping if the offset + +// write size is larger than the length of the memory's data. Warning: +// out-of-bounds loads and stores can read junk out of or corrupt other +// memories instead of trapping. + +#include "ir/module-utils.h" +#include "ir/names.h" +#include "wasm-builder.h" +#include <pass.h> +#include <wasm.h> + +namespace wasm { + +struct MultiMemoryLowering : public Pass { + Module* wasm = nullptr; + // The name of the single memory that exists after this pass is run + Name combinedMemory; + // The type of the single memory + Type pointerType; + // Used to indicate the type of the single memory when creating instructions + // (memory.grow, memory.size) for that memory + Builder::MemoryInfo memoryInfo; + // If the combined memory is shared + bool isShared; + // The initial page size of the combined memory + Address totalInitialPages; + // The max page size of the combined memory + Address totalMaxPages; + // There is no offset for the first memory, so offsetGlobalNames will always + // have a size that is one less than the count of memories at the time this + // pass is run. Use helper getOffsetGlobal(Index) to index the vector + // conveniently without having to manipulate the index directly + std::vector<Name> offsetGlobalNames; + // Maps from the name of the memory to its index as seen in the + // module->memories vector + std::unordered_map<Name, Index> memoryIdxMap; + // A vector of the memory size function names that were created proactively + // for each memory + std::vector<Name> memorySizeNames; + // A vector of the memory grow functions that were created proactively for + // each memory + std::vector<Name> memoryGrowNames; + + void run(Module* module) override { + module->features.disable(FeatureSet::MultiMemories); + + // If there are no memories or 1 memory, skip this pass + if (module->memories.size() <= 1) { + return; + } + + this->wasm = module; + + prepCombinedMemory(); + addOffsetGlobals(); + adjustActiveDataSegmentOffsets(); + createMemorySizeFunctions(); + createMemoryGrowFunctions(); + removeExistingMemories(); + addCombinedMemory(); + + struct Replacer : public WalkerPass<PostWalker<Replacer>> { + MultiMemoryLowering& parent; + Builder builder; + Replacer(MultiMemoryLowering& parent, Module& wasm) + : parent(parent), builder(wasm) {} + // Avoid visiting the custom functions added by the parent pass + // MultiMemoryLowering + void walkFunction(Function* func) { + for (Name funcName : parent.memorySizeNames) { + if (funcName == func->name) { + return; + } + } + for (Name funcName : parent.memoryGrowNames) { + if (funcName == func->name) { + return; + } + } + super::walkFunction(func); + } + + void visitMemoryGrow(MemoryGrow* curr) { + auto idx = parent.memoryIdxMap.at(curr->memory); + Name funcName = parent.memoryGrowNames[idx]; + replaceCurrent(builder.makeCall(funcName, {curr->delta}, curr->type)); + } + + void visitMemorySize(MemorySize* curr) { + auto idx = parent.memoryIdxMap.at(curr->memory); + Name funcName = parent.memorySizeNames[idx]; + replaceCurrent(builder.makeCall(funcName, {}, curr->type)); + } + + // TODO: Add an option to add bounds checks. + void visitLoad(Load* curr) { + auto idx = parent.memoryIdxMap.at(curr->memory); + auto global = parent.getOffsetGlobal(idx); + curr->memory = parent.combinedMemory; + if (!global) { + return; + } + curr->ptr = builder.makeBinary( + Abstract::getBinary(parent.pointerType, Abstract::Add), + builder.makeGlobalGet(global, parent.pointerType), + curr->ptr); + } + + // We diverge from the spec here and are not trapping if the offset + type + // / 8 is larger than the length of the memory's data. Warning, + // out-of-bounds loads and stores can read junk out of or corrupt other + // memories instead of trapping + void visitStore(Store* curr) { + auto idx = parent.memoryIdxMap.at(curr->memory); + auto global = parent.getOffsetGlobal(idx); + curr->memory = parent.combinedMemory; + if (!global) { + return; + } + curr->ptr = builder.makeBinary( + Abstract::getBinary(parent.pointerType, Abstract::Add), + builder.makeGlobalGet(global, parent.pointerType), + curr->ptr); + } + }; + Replacer(*this, *wasm).run(getPassRunner(), wasm); + } + + // Returns the global name for the given idx. There is no global for the first + // idx, so an empty name is returned + Name getOffsetGlobal(Index idx) { + // There is no offset global for the first memory + if (idx == 0) { + return Name(); + } + + // Since there is no offset global for the first memory, we need to + // subtract one when indexing into the offsetGlobalName vector + return offsetGlobalNames[idx - 1]; + } + + // Whether the idx represents the last memory. Since there is no offset global + // for the first memory, the last memory is represented by the size of + // offsetGlobalNames + bool isLastMemory(Index idx) { return idx == offsetGlobalNames.size(); } + + void prepCombinedMemory() { + pointerType = wasm->memories[0]->indexType; + memoryInfo = pointerType == Type::i32 ? Builder::MemoryInfo::Memory32 + : Builder::MemoryInfo::Memory64; + isShared = wasm->memories[0]->shared; + for (auto& memory : wasm->memories) { + // We are assuming that each memory is configured the same as the first + // and assert if any of the memories does not match this configuration + assert(memory->shared == isShared); + assert(memory->indexType == pointerType); + + // Calculating the total initial and max page size for the combined memory + // by totaling the initial and max page sizes for the memories in the + // module + totalInitialPages = totalInitialPages + memory->initial; + if (memory->hasMax()) { + totalMaxPages = totalMaxPages + memory->max; + } + } + // Ensuring valid initial and max page sizes that do not exceed the number + // of pages addressable by the pointerType + Address maxSize = + pointerType == Type::i32 ? Memory::kMaxSize32 : Memory::kMaxSize64; + if (totalMaxPages > maxSize || totalMaxPages == 0) { + totalMaxPages = Memory::kUnlimitedSize; + } + if (totalInitialPages > totalMaxPages) { + totalInitialPages = totalMaxPages; + } + + // Creating the combined memory name so we can reference the combined memory + // in subsequent instructions before it is added to the module + combinedMemory = Names::getValidMemoryName(*wasm, "combined_memory"); + } + + void addOffsetGlobals() { + auto addGlobal = [&](Name name, size_t offset) { + auto global = Builder::makeGlobal( + name, + pointerType, + Builder(*wasm).makeConst(Literal::makeFromInt64(offset, pointerType)), + Builder::Mutable); + wasm->addGlobal(std::move(global)); + }; + + size_t offsetRunningTotal = 0; + for (Index i = 0; i < wasm->memories.size(); i++) { + auto& memory = wasm->memories[i]; + memoryIdxMap[memory->name] = i; + // We don't need a page offset global for the first memory as it's always + // 0 + if (i != 0) { + Name name = Names::getValidGlobalName( + *wasm, memory->name.toString() + "_byte_offset"); + offsetGlobalNames.push_back(std::move(name)); + addGlobal(name, offsetRunningTotal * Memory::kPageSize); + } + offsetRunningTotal += memory->initial; + } + } + + void adjustActiveDataSegmentOffsets() { + Builder builder(*wasm); + ModuleUtils::iterActiveDataSegments(*wasm, [&](DataSegment* dataSegment) { + assert(dataSegment->offset->is<Const>() && + "TODO: handle non-const segment offsets"); + auto idx = memoryIdxMap.at(dataSegment->memory); + dataSegment->memory = combinedMemory; + // No need to update the offset of data segments for the first memory + if (idx != 0) { + auto offsetGlobalName = getOffsetGlobal(idx); + assert(wasm->features.hasExtendedConst()); + dataSegment->offset = builder.makeBinary( + Abstract::getBinary(pointerType, Abstract::Add), + builder.makeGlobalGet(offsetGlobalName, pointerType), + dataSegment->offset); + } + }); + } + + void createMemorySizeFunctions() { + for (Index i = 0; i < wasm->memories.size(); i++) { + auto function = memorySize(i, wasm->memories[i]->name); + memorySizeNames.push_back(function->name); + wasm->addFunction(std::move(function)); + } + } + + void createMemoryGrowFunctions() { + for (Index i = 0; i < wasm->memories.size(); i++) { + auto function = memoryGrow(i, wasm->memories[i]->name); + memoryGrowNames.push_back(function->name); + wasm->addFunction(std::move(function)); + } + } + + // This function replaces memory.grow instruction calls in the wasm module. + // Because the multiple discrete memories are lowered into a single memory, + // we need to adjust offsets as a particular memory receives an + // instruction to grow. + std::unique_ptr<Function> memoryGrow(Index memIdx, Name memoryName) { + Builder builder(*wasm); + Name name = memoryName.toString() + "_grow"; + Name functionName = Names::getValidFunctionName(*wasm, name); + auto function = Builder::makeFunction( + functionName, Signature(pointerType, pointerType), {}); + function->setLocalName(0, "page_delta"); + auto pageSizeConst = [&]() { + return builder.makeConst(Literal(Memory::kPageSize)); + }; + auto getOffsetDelta = [&]() { + return builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Mul), + builder.makeLocalGet(0, pointerType), + pageSizeConst()); + }; + auto getMoveSource = [&](Name global) { + return builder.makeGlobalGet(global, pointerType); + }; + Expression* functionBody; + Index sizeLocal = -1; + + Index returnLocal = + Builder::addVar(function.get(), "return_size", pointerType); + functionBody = builder.blockify(builder.makeLocalSet( + returnLocal, builder.makeCall(memorySizeNames[memIdx], {}, pointerType))); + + if (!isLastMemory(memIdx)) { + sizeLocal = Builder::addVar(function.get(), "memory_size", pointerType); + functionBody = builder.blockify( + functionBody, + builder.makeLocalSet( + sizeLocal, builder.makeMemorySize(combinedMemory, memoryInfo))); + } + + // TODO: Check the result of makeMemoryGrow for errors and return the error + // instead + functionBody = builder.blockify( + functionBody, + builder.makeDrop(builder.makeMemoryGrow( + builder.makeLocalGet(0, pointerType), combinedMemory, memoryInfo))); + + // If we are not growing the last memory, then we need to copy data, + // shifting it over to accomodate the increase from page_delta + if (!isLastMemory(memIdx)) { + // This offset is the starting pt for copying + auto offsetGlobalName = getOffsetGlobal(memIdx + 1); + functionBody = builder.blockify( + functionBody, + builder.makeMemoryCopy( + // destination + builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Add), + getMoveSource(offsetGlobalName), + getOffsetDelta()), + // source + getMoveSource(offsetGlobalName), + // size + builder.makeBinary( + Abstract::getBinary(pointerType, Abstract::Sub), + builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Mul), + builder.makeLocalGet(sizeLocal, pointerType), + pageSizeConst()), + getMoveSource(offsetGlobalName)), + combinedMemory, + combinedMemory)); + } + + // Adjust the offsets of the globals impacted by the memory.grow call + for (Index i = memIdx; i < offsetGlobalNames.size(); i++) { + auto& offsetGlobalName = offsetGlobalNames[i]; + functionBody = builder.blockify( + functionBody, + builder.makeGlobalSet( + offsetGlobalName, + builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Add), + getMoveSource(offsetGlobalName), + getOffsetDelta()))); + } + + functionBody = builder.blockify( + functionBody, builder.makeLocalGet(returnLocal, pointerType)); + + function->body = functionBody; + return function; + } + + // This function replaces memory.size instructions with a function that can + // return the size of each memory as if each was discrete and separate. + std::unique_ptr<Function> memorySize(Index memIdx, Name memoryName) { + Builder builder(*wasm); + Name name = memoryName.toString() + "_size"; + Name functionName = Names::getValidFunctionName(*wasm, name); + auto function = Builder::makeFunction( + functionName, Signature(Type::none, pointerType), {}); + Expression* functionBody; + auto pageSizeConst = [&]() { + return builder.makeConst(Literal(Memory::kPageSize)); + }; + auto getOffsetInPageUnits = [&](Name global) { + return builder.makeBinary( + Abstract::getBinary(pointerType, Abstract::DivU), + builder.makeGlobalGet(global, pointerType), + pageSizeConst()); + }; + + // offsetGlobalNames does not keep track of a global for the offset of + // wasm->memories[0] because it's always 0. As a result, the below + // calculations that involve offsetGlobalNames are intrinsically "offset". + // Thus, offsetGlobalNames[0] is the offset for wasm->memories[1] and + // the size of wasm->memories[0]. + if (memIdx == 0) { + auto offsetGlobalName = getOffsetGlobal(1); + functionBody = builder.blockify( + builder.makeReturn(getOffsetInPageUnits(offsetGlobalName))); + } else if (isLastMemory(memIdx)) { + auto offsetGlobalName = getOffsetGlobal(memIdx); + functionBody = builder.blockify(builder.makeReturn( + builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Sub), + builder.makeMemorySize(combinedMemory, memoryInfo), + getOffsetInPageUnits(offsetGlobalName)))); + } else { + auto offsetGlobalName = getOffsetGlobal(memIdx); + auto nextOffsetGlobalName = getOffsetGlobal(memIdx + 1); + functionBody = builder.blockify(builder.makeReturn( + builder.makeBinary(Abstract::getBinary(pointerType, Abstract::Sub), + getOffsetInPageUnits(nextOffsetGlobalName), + getOffsetInPageUnits(offsetGlobalName)))); + } + + function->body = functionBody; + return function; + } + + void removeExistingMemories() { + wasm->removeMemories([&](Memory* curr) { return true; }); + } + + void addCombinedMemory() { + auto memory = Builder::makeMemory(combinedMemory); + memory->shared = isShared; + memory->indexType = pointerType; + memory->initial = totalInitialPages; + memory->max = totalMaxPages; + wasm->addMemory(std::move(memory)); + } +}; + +Pass* createMultiMemoryLoweringPass() { return new MultiMemoryLowering(); } + +} // namespace wasm |