diff options
Diffstat (limited to 'src/passes/MultiMemoryLowering.cpp')
-rw-r--r-- | src/passes/MultiMemoryLowering.cpp | 180 |
1 files changed, 110 insertions, 70 deletions
diff --git a/src/passes/MultiMemoryLowering.cpp b/src/passes/MultiMemoryLowering.cpp index 527a15804..57be529fa 100644 --- a/src/passes/MultiMemoryLowering.cpp +++ b/src/passes/MultiMemoryLowering.cpp @@ -23,11 +23,14 @@ // multi-memories feature also prevents later passes from adding additional // memories. // -// Also worth noting that we are diverging from the spec with regards to -// handling load and store instructions. We are not trapping if the offset + -// write size is larger than the length of the memory's data. Warning: -// out-of-bounds loads and stores can read junk out of or corrupt other -// memories instead of trapping. +// The offset computation in function maybeMakeBoundsCheck is not precise +// according to the spec. In the spec offsets do not overflow as +// twos-complement, but i32.add does. Concretely, a load from address 1000 with +// offset 0xffffffff should actually trap, as the combined number is greater +// than 32 bits. But with an add, 1000 + 0xffffffff = 999 due to overflow, which +// would not trap. In theory we could compute like the spec, by expanding the +// i32s to i64s and adding there (where we won't overflow), but we don't have +// i128s to handle i64 overflow. #include "ir/module-utils.h" #include "ir/names.h" @@ -67,6 +70,103 @@ struct MultiMemoryLowering : public Pass { // each memory std::vector<Name> memoryGrowNames; + bool checkBounds = false; + + MultiMemoryLowering(bool checkBounds) : checkBounds(checkBounds) {} + + struct Replacer : public WalkerPass<PostWalker<Replacer>> { + MultiMemoryLowering& parent; + Builder builder; + Replacer(MultiMemoryLowering& parent, Module& wasm) + : parent(parent), builder(wasm) {} + // Avoid visiting the custom functions added by the parent pass + // MultiMemoryLowering + void walkFunction(Function* func) { + for (Name funcName : parent.memorySizeNames) { + if (funcName == func->name) { + return; + } + } + for (Name funcName : parent.memoryGrowNames) { + if (funcName == func->name) { + return; + } + } + super::walkFunction(func); + } + + void visitMemoryGrow(MemoryGrow* curr) { + auto idx = parent.memoryIdxMap.at(curr->memory); + Name funcName = parent.memoryGrowNames[idx]; + replaceCurrent(builder.makeCall(funcName, {curr->delta}, curr->type)); + } + + void visitMemorySize(MemorySize* curr) { + auto idx = parent.memoryIdxMap.at(curr->memory); + Name funcName = parent.memorySizeNames[idx]; + replaceCurrent(builder.makeCall(funcName, {}, curr->type)); + } + + template<typename T> Expression* getPtr(T* curr, Function* func) { + auto memoryIdx = parent.memoryIdxMap.at(curr->memory); + auto offsetGlobal = parent.getOffsetGlobal(memoryIdx); + Expression* ptrValue; + if (offsetGlobal) { + ptrValue = builder.makeBinary( + Abstract::getBinary(parent.pointerType, Abstract::Add), + builder.makeGlobalGet(offsetGlobal, parent.pointerType), + curr->ptr); + } else { + ptrValue = curr->ptr; + } + + if (parent.checkBounds) { + Index ptrIdx = Builder::addVar(getFunction(), parent.pointerType); + Expression* ptrSet = builder.makeLocalSet(ptrIdx, ptrValue); + Expression* boundsCheck = makeBoundsCheck(curr, ptrIdx, memoryIdx); + Expression* ptrGet = builder.makeLocalGet(ptrIdx, parent.pointerType); + return builder.makeBlock({ptrSet, boundsCheck, ptrGet}); + } + + return ptrValue; + } + + template<typename T> + Expression* makeBoundsCheck(T* curr, Index ptrIdx, Index memoryIdx) { + Name memorySizeFunc = parent.memorySizeNames[memoryIdx]; + Expression* boundsCheck = builder.makeIf( + builder.makeBinary( + Abstract::getBinary(parent.pointerType, Abstract::GtU), + builder.makeBinary( + // ptr + offset (ea from wasm spec) + bit width + // two builder Adds, we'll add the first two operands in the first + // add and then add the third operand in the second add + Abstract::getBinary(parent.pointerType, Abstract::Add), + builder.makeBinary( + Abstract::getBinary(parent.pointerType, Abstract::Add), + builder.makeLocalGet(ptrIdx, parent.pointerType), + builder.makeConstPtr(curr->offset, parent.pointerType)), + builder.makeConstPtr(curr->bytes, parent.pointerType)), + builder.makeCall(memorySizeFunc, {}, parent.pointerType)), + builder.makeUnreachable()); + return boundsCheck; + } + + template<typename T> void setMemory(T* curr) { + curr->memory = parent.combinedMemory; + } + + void visitLoad(Load* curr) { + curr->ptr = getPtr(curr, getFunction()); + setMemory(curr); + } + + void visitStore(Store* curr) { + curr->ptr = getPtr(curr, getFunction()); + setMemory(curr); + } + }; + void run(Module* module) override { module->features.disable(FeatureSet::MultiMemories); @@ -85,70 +185,6 @@ struct MultiMemoryLowering : public Pass { removeExistingMemories(); addCombinedMemory(); - struct Replacer : public WalkerPass<PostWalker<Replacer>> { - MultiMemoryLowering& parent; - Builder builder; - Replacer(MultiMemoryLowering& parent, Module& wasm) - : parent(parent), builder(wasm) {} - // Avoid visiting the custom functions added by the parent pass - // MultiMemoryLowering - void walkFunction(Function* func) { - for (Name funcName : parent.memorySizeNames) { - if (funcName == func->name) { - return; - } - } - for (Name funcName : parent.memoryGrowNames) { - if (funcName == func->name) { - return; - } - } - super::walkFunction(func); - } - - void visitMemoryGrow(MemoryGrow* curr) { - auto idx = parent.memoryIdxMap.at(curr->memory); - Name funcName = parent.memoryGrowNames[idx]; - replaceCurrent(builder.makeCall(funcName, {curr->delta}, curr->type)); - } - - void visitMemorySize(MemorySize* curr) { - auto idx = parent.memoryIdxMap.at(curr->memory); - Name funcName = parent.memorySizeNames[idx]; - replaceCurrent(builder.makeCall(funcName, {}, curr->type)); - } - - // TODO: Add an option to add bounds checks. - void visitLoad(Load* curr) { - auto idx = parent.memoryIdxMap.at(curr->memory); - auto global = parent.getOffsetGlobal(idx); - curr->memory = parent.combinedMemory; - if (!global) { - return; - } - curr->ptr = builder.makeBinary( - Abstract::getBinary(parent.pointerType, Abstract::Add), - builder.makeGlobalGet(global, parent.pointerType), - curr->ptr); - } - - // We diverge from the spec here and are not trapping if the offset + type - // / 8 is larger than the length of the memory's data. Warning, - // out-of-bounds loads and stores can read junk out of or corrupt other - // memories instead of trapping - void visitStore(Store* curr) { - auto idx = parent.memoryIdxMap.at(curr->memory); - auto global = parent.getOffsetGlobal(idx); - curr->memory = parent.combinedMemory; - if (!global) { - return; - } - curr->ptr = builder.makeBinary( - Abstract::getBinary(parent.pointerType, Abstract::Add), - builder.makeGlobalGet(global, parent.pointerType), - curr->ptr); - } - }; Replacer(*this, *wasm).run(getPassRunner(), wasm); } @@ -421,6 +457,10 @@ struct MultiMemoryLowering : public Pass { } }; -Pass* createMultiMemoryLoweringPass() { return new MultiMemoryLowering(); } +Pass* createMultiMemoryLoweringPass() { return new MultiMemoryLowering(false); } + +Pass* createMultiMemoryLoweringWithBoundsChecksPass() { + return new MultiMemoryLowering(true); +} } // namespace wasm |