Diffstat (limited to 'src/passes/MultiMemoryLowering.cpp')
-rw-r--r-- src/passes/MultiMemoryLowering.cpp | 180
 1 file changed, 110 insertions(+), 70 deletions(-)
diff --git a/src/passes/MultiMemoryLowering.cpp b/src/passes/MultiMemoryLowering.cpp
index 527a15804..57be529fa 100644
--- a/src/passes/MultiMemoryLowering.cpp
+++ b/src/passes/MultiMemoryLowering.cpp
@@ -23,11 +23,14 @@
// multi-memories feature also prevents later passes from adding additional
// memories.
//
-// Also worth noting that we are diverging from the spec with regards to
-// handling load and store instructions. We are not trapping if the offset +
-// write size is larger than the length of the memory's data. Warning:
-// out-of-bounds loads and stores can read junk out of or corrupt other
-// memories instead of trapping.
+// The offset computation in makeBoundsCheck below is not precise according to
+// the spec. Per the spec, the effective address does not wrap around, but
+// i32.add does (two's-complement wraparound). Concretely, a load from address
+// 1000 with offset 0xffffffff should trap, because the effective address does
+// not fit in 32 bits. With an i32.add, however, 1000 + 0xffffffff wraps to
+// 999, which would not trap. In theory we could compute the address as the
+// spec does, by widening the i32s to i64s and adding there (where the sum
+// cannot overflow), but for 64-bit memories we would in turn need i128s to
+// handle i64 overflow.
#include "ir/module-utils.h"
#include "ir/names.h"
@@ -67,6 +70,103 @@ struct MultiMemoryLowering : public Pass {
// each memory
std::vector<Name> memoryGrowNames;
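+ // Whether to also emit bounds checks for loads and stores into the combined
+ // memory (see makeBoundsCheck below)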
+ bool checkBounds = false;
+
+ MultiMemoryLowering(bool checkBounds) : checkBounds(checkBounds) {}
+
+ struct Replacer : public WalkerPass<PostWalker<Replacer>> {
+ MultiMemoryLowering& parent;
+ Builder builder;
+ Replacer(MultiMemoryLowering& parent, Module& wasm)
+ : parent(parent), builder(wasm) {}
+ // Avoid visiting the custom functions added by the parent pass
+ // MultiMemoryLowering
+ void walkFunction(Function* func) {
+ for (Name funcName : parent.memorySizeNames) {
+ if (funcName == func->name) {
+ return;
+ }
+ }
+ for (Name funcName : parent.memoryGrowNames) {
+ if (funcName == func->name) {
+ return;
+ }
+ }
+ super::walkFunction(func);
+ }
+
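+ // memory.grow and memory.size are rewritten into calls to the per-memory
+ // helper functions created by the parent pass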
+ void visitMemoryGrow(MemoryGrow* curr) {
+ auto idx = parent.memoryIdxMap.at(curr->memory);
+ Name funcName = parent.memoryGrowNames[idx];
+ replaceCurrent(builder.makeCall(funcName, {curr->delta}, curr->type));
+ }
+
+ void visitMemorySize(MemorySize* curr) {
+ auto idx = parent.memoryIdxMap.at(curr->memory);
+ Name funcName = parent.memorySizeNames[idx];
+ replaceCurrent(builder.makeCall(funcName, {}, curr->type));
+ }
+
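+ // Computes the address in the combined memory: the original pointer plus
+ // this memory's offset global (if any), optionally wrapped in a bounds check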
+ template<typename T> Expression* getPtr(T* curr, Function* func) {
+ auto memoryIdx = parent.memoryIdxMap.at(curr->memory);
+ auto offsetGlobal = parent.getOffsetGlobal(memoryIdx);
+ Expression* ptrValue;
+ if (offsetGlobal) {
+ ptrValue = builder.makeBinary(
+ Abstract::getBinary(parent.pointerType, Abstract::Add),
+ builder.makeGlobalGet(offsetGlobal, parent.pointerType),
+ curr->ptr);
+ } else {
+ ptrValue = curr->ptr;
+ }
+
+ if (parent.checkBounds) {
+ Index ptrIdx = Builder::addVar(func, parent.pointerType);
+ Expression* ptrSet = builder.makeLocalSet(ptrIdx, ptrValue);
+ Expression* boundsCheck = makeBoundsCheck(curr, ptrIdx, memoryIdx);
+ Expression* ptrGet = builder.makeLocalGet(ptrIdx, parent.pointerType);
+ return builder.makeBlock({ptrSet, boundsCheck, ptrGet});
+ }
+
+ return ptrValue;
+ }
+
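+ // Emits a check that traps (via unreachable) when the access would reach
+ // past this memory's current size; see the comment at the top of the file
+ // on how the effective address is approximated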
+ template<typename T>
+ Expression* makeBoundsCheck(T* curr, Index ptrIdx, Index memoryIdx) {
+ Name memorySizeFunc = parent.memorySizeNames[memoryIdx];
+ Expression* boundsCheck = builder.makeIf(
+ builder.makeBinary(
+ Abstract::getBinary(parent.pointerType, Abstract::GtU),
+ builder.makeBinary(
+ // ptr + offset + access size in bytes: the effective address ("ea") from
+ // the wasm spec plus the width of the access, built with two Adds (first
+ // ptr + offset, then + the access size)
+ Abstract::getBinary(parent.pointerType, Abstract::Add),
+ builder.makeBinary(
+ Abstract::getBinary(parent.pointerType, Abstract::Add),
+ builder.makeLocalGet(ptrIdx, parent.pointerType),
+ builder.makeConstPtr(curr->offset, parent.pointerType)),
+ builder.makeConstPtr(curr->bytes, parent.pointerType)),
+ // the memory.size helper returns pages, so scale it to bytes before
+ // comparing against the effective address
+ builder.makeBinary(
+ Abstract::getBinary(parent.pointerType, Abstract::Mul),
+ builder.makeCall(memorySizeFunc, {}, parent.pointerType),
+ builder.makeConstPtr(Memory::kPageSize, parent.pointerType))),
+ builder.makeUnreachable());
+ return boundsCheck;
+ }
+
+ template<typename T> void setMemory(T* curr) {
+ curr->memory = parent.combinedMemory;
+ }
+
+ void visitLoad(Load* curr) {
+ curr->ptr = getPtr(curr, getFunction());
+ setMemory(curr);
+ }
+
+ void visitStore(Store* curr) {
+ curr->ptr = getPtr(curr, getFunction());
+ setMemory(curr);
+ }
+ };
+
void run(Module* module) override {
module->features.disable(FeatureSet::MultiMemories);
@@ -85,70 +185,6 @@ struct MultiMemoryLowering : public Pass {
removeExistingMemories();
addCombinedMemory();
- struct Replacer : public WalkerPass<PostWalker<Replacer>> {
- MultiMemoryLowering& parent;
- Builder builder;
- Replacer(MultiMemoryLowering& parent, Module& wasm)
- : parent(parent), builder(wasm) {}
- // Avoid visiting the custom functions added by the parent pass
- // MultiMemoryLowering
- void walkFunction(Function* func) {
- for (Name funcName : parent.memorySizeNames) {
- if (funcName == func->name) {
- return;
- }
- }
- for (Name funcName : parent.memoryGrowNames) {
- if (funcName == func->name) {
- return;
- }
- }
- super::walkFunction(func);
- }
-
- void visitMemoryGrow(MemoryGrow* curr) {
- auto idx = parent.memoryIdxMap.at(curr->memory);
- Name funcName = parent.memoryGrowNames[idx];
- replaceCurrent(builder.makeCall(funcName, {curr->delta}, curr->type));
- }
-
- void visitMemorySize(MemorySize* curr) {
- auto idx = parent.memoryIdxMap.at(curr->memory);
- Name funcName = parent.memorySizeNames[idx];
- replaceCurrent(builder.makeCall(funcName, {}, curr->type));
- }
-
- // TODO: Add an option to add bounds checks.
- void visitLoad(Load* curr) {
- auto idx = parent.memoryIdxMap.at(curr->memory);
- auto global = parent.getOffsetGlobal(idx);
- curr->memory = parent.combinedMemory;
- if (!global) {
- return;
- }
- curr->ptr = builder.makeBinary(
- Abstract::getBinary(parent.pointerType, Abstract::Add),
- builder.makeGlobalGet(global, parent.pointerType),
- curr->ptr);
- }
-
- // We diverge from the spec here and are not trapping if the offset + type
- // / 8 is larger than the length of the memory's data. Warning,
- // out-of-bounds loads and stores can read junk out of or corrupt other
- // memories instead of trapping
- void visitStore(Store* curr) {
- auto idx = parent.memoryIdxMap.at(curr->memory);
- auto global = parent.getOffsetGlobal(idx);
- curr->memory = parent.combinedMemory;
- if (!global) {
- return;
- }
- curr->ptr = builder.makeBinary(
- Abstract::getBinary(parent.pointerType, Abstract::Add),
- builder.makeGlobalGet(global, parent.pointerType),
- curr->ptr);
- }
- };
Replacer(*this, *wasm).run(getPassRunner(), wasm);
}
@@ -421,6 +457,10 @@ struct MultiMemoryLowering : public Pass {
}
};
-Pass* createMultiMemoryLoweringPass() { return new MultiMemoryLowering(); }
+Pass* createMultiMemoryLoweringPass() { return new MultiMemoryLowering(false); }
+
+Pass* createMultiMemoryLoweringWithBoundsChecksPass() {
+ return new MultiMemoryLowering(true);
+}
} // namespace wasm
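To illustrate the overflow caveat in the file-header comment above, here is a minimal standalone sketch (purely illustrative; nothing below is part of the pass or of Binaryen) of why a 32-bit add cannot reproduce the spec's effective-address computation, and why widening to 64 bits can:

  #include <cstdint>
  #include <cstdio>

  int main() {
    uint32_t ptr = 1000;
    uint32_t offset = 0xffffffffu;
    // i32.add semantics: the sum wraps around to 999, so an add-based bounds
    // check would not notice an out-of-bounds address
    uint32_t wrapped = ptr + offset;
    // spec semantics: the effective address is computed without wrapping, so
    // it exceeds the 32-bit address space entirely and the access must trap
    uint64_t effective = uint64_t(ptr) + uint64_t(offset);
    printf("wrapped:   %u\n", wrapped);                         // 999
    printf("effective: %llu\n", (unsigned long long)effective); // 4294968295
    printf("within a 32-bit address space: %s\n",
           effective <= UINT32_MAX ? "maybe" : "no, must trap");
  }

For a 64-bit (memory64) address space the same widening trick would require 128-bit arithmetic, which is the limitation the header comment refers to.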