summaryrefslogtreecommitdiff
path: root/src/passes/MultiMemoryLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/MultiMemoryLowering.cpp')
-rw-r--r--src/passes/MultiMemoryLowering.cpp108
1 files changed, 64 insertions, 44 deletions
diff --git a/src/passes/MultiMemoryLowering.cpp b/src/passes/MultiMemoryLowering.cpp
index 0b066c39b..1183a7d74 100644
--- a/src/passes/MultiMemoryLowering.cpp
+++ b/src/passes/MultiMemoryLowering.cpp
@@ -115,25 +115,67 @@ struct MultiMemoryLowering : public Pass {
replaceCurrent(builder.makeCall(funcName, {}, curr->type));
}
- template<typename T>
- Expression* getPtr(T* curr, Function* func, Index bytes) {
- auto memoryIdx = parent.memoryIdxMap.at(curr->memory);
+ Expression* addOffsetGlobal(Expression* toExpr, Name memory) {
+ auto memoryIdx = parent.memoryIdxMap.at(memory);
auto offsetGlobal = parent.getOffsetGlobal(memoryIdx);
- Expression* ptrValue;
+ Expression* returnExpr;
if (offsetGlobal) {
- ptrValue = builder.makeBinary(
+ returnExpr = builder.makeBinary(
Abstract::getBinary(parent.pointerType, Abstract::Add),
builder.makeGlobalGet(offsetGlobal, parent.pointerType),
- curr->ptr);
+ toExpr);
} else {
- ptrValue = curr->ptr;
+ returnExpr = toExpr;
}
+ return returnExpr;
+ }
+
+ Expression* makeAddGtuTrap(Expression* leftOperand,
+ Expression* rightOperand,
+ Expression* limit) {
+ Expression* gtuTrap = builder.makeIf(
+ builder.makeBinary(
+ Abstract::getBinary(parent.pointerType, Abstract::GtU),
+ builder.makeBinary(
+ Abstract::getBinary(parent.pointerType, Abstract::Add),
+ leftOperand,
+ rightOperand),
+ limit),
+ builder.makeUnreachable());
+ return gtuTrap;
+ }
+
+ Expression* makeAddGtuMemoryTrap(Expression* leftOperand,
+ Expression* rightOperand,
+ Name memory) {
+ auto memoryIdx = parent.memoryIdxMap.at(memory);
+ Name memorySizeFunc = parent.memorySizeNames[memoryIdx];
+ Expression* gtuMemoryTrap = makeAddGtuTrap(
+ leftOperand,
+ rightOperand,
+ builder.makeCall(memorySizeFunc, {}, parent.pointerType));
+ return gtuMemoryTrap;
+ }
+
+ template<typename T>
+ Expression* makePtrBoundsCheck(T* curr, Index ptrIdx, Index bytes) {
+ Expression* boundsCheck = makeAddGtuMemoryTrap(
+ builder.makeBinary(
+ // ptr + offset (ea from wasm spec) + bit width
+ Abstract::getBinary(parent.pointerType, Abstract::Add),
+ builder.makeLocalGet(ptrIdx, parent.pointerType),
+ builder.makeConstPtr(curr->offset, parent.pointerType)),
+ builder.makeConstPtr(bytes, parent.pointerType),
+ curr->memory);
+ return boundsCheck;
+ }
+ template<typename T> Expression* getPtr(T* curr, Index bytes) {
+ Expression* ptrValue = addOffsetGlobal(curr->ptr, curr->memory);
if (parent.checkBounds) {
Index ptrIdx = Builder::addVar(getFunction(), parent.pointerType);
Expression* ptrSet = builder.makeLocalSet(ptrIdx, ptrValue);
- Expression* boundsCheck =
- makeBoundsCheck(curr, ptrIdx, memoryIdx, bytes);
+ Expression* boundsCheck = makePtrBoundsCheck(curr, ptrIdx, bytes);
Expression* ptrGet = builder.makeLocalGet(ptrIdx, parent.pointerType);
return builder.makeBlock({ptrSet, boundsCheck, ptrGet});
}
@@ -141,74 +183,52 @@ struct MultiMemoryLowering : public Pass {
return ptrValue;
}
- template<typename T>
- Expression*
- makeBoundsCheck(T* curr, Index ptrIdx, Index memoryIdx, Index bytes) {
- Name memorySizeFunc = parent.memorySizeNames[memoryIdx];
- Expression* boundsCheck = builder.makeIf(
- builder.makeBinary(
- Abstract::getBinary(parent.pointerType, Abstract::GtU),
- builder.makeBinary(
- // ptr + offset (ea from wasm spec) + bit width
- // two builder Adds, we'll add the first two operands in the first
- // add and then add the third operand in the second add
- Abstract::getBinary(parent.pointerType, Abstract::Add),
- builder.makeBinary(
- Abstract::getBinary(parent.pointerType, Abstract::Add),
- builder.makeLocalGet(ptrIdx, parent.pointerType),
- builder.makeConstPtr(curr->offset, parent.pointerType)),
- builder.makeConstPtr(bytes, parent.pointerType)),
- builder.makeCall(memorySizeFunc, {}, parent.pointerType)),
- builder.makeUnreachable());
- return boundsCheck;
- }
-
template<typename T> void setMemory(T* curr) {
curr->memory = parent.combinedMemory;
}
void visitLoad(Load* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->bytes);
+ curr->ptr = getPtr(curr, curr->bytes);
setMemory(curr);
}
void visitStore(Store* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->bytes);
+ curr->ptr = getPtr(curr, curr->bytes);
setMemory(curr);
}
void visitSIMDLoad(SIMDLoad* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes());
+ curr->ptr = getPtr(curr, curr->getMemBytes());
setMemory(curr);
}
void visitSIMDLoadSplat(SIMDLoad* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes());
+ curr->ptr = getPtr(curr, curr->getMemBytes());
setMemory(curr);
}
void visitSIMDLoadExtend(SIMDLoad* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes());
+ curr->ptr = getPtr(curr, curr->getMemBytes());
setMemory(curr);
}
void visitSIMDLoadZero(SIMDLoad* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes());
+ curr->ptr = getPtr(curr, curr->getMemBytes());
setMemory(curr);
}
void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes());
+ curr->ptr = getPtr(curr, curr->getMemBytes());
setMemory(curr);
}
void visitAtomicRMW(AtomicRMW* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->bytes);
+ curr->ptr = getPtr(curr, curr->bytes);
setMemory(curr);
}
void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
- curr->ptr = getPtr(curr, getFunction(), curr->bytes);
+ curr->ptr = getPtr(curr, curr->bytes);
setMemory(curr);
}
@@ -226,12 +246,12 @@ struct MultiMemoryLowering : public Pass {
default:
WASM_UNREACHABLE("unexpected type");
}
- curr->ptr = getPtr(curr, getFunction(), bytes);
+ curr->ptr = getPtr(curr, bytes);
setMemory(curr);
}
void visitAtomicNotify(AtomicNotify* curr) {
- curr->ptr = getPtr(curr, getFunction(), Index(4));
+ curr->ptr = getPtr(curr, Index(4));
setMemory(curr);
}
};
@@ -247,7 +267,7 @@ struct MultiMemoryLowering : public Pass {
this->wasm = module;
prepCombinedMemory();
- addOffsetGlobals();
+ makeOffsetGlobals();
adjustActiveDataSegmentOffsets();
createMemorySizeFunctions();
createMemoryGrowFunctions();
@@ -310,7 +330,7 @@ struct MultiMemoryLowering : public Pass {
combinedMemory = Names::getValidMemoryName(*wasm, "combined_memory");
}
- void addOffsetGlobals() {
+ void makeOffsetGlobals() {
auto addGlobal = [&](Name name, size_t offset) {
auto global = Builder::makeGlobal(
name,