diff options
author | Ashley Nelson <nashley@google.com> | 2022-12-15 09:59:41 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-15 09:59:41 -0800 |
commit | 42896405d13d795c6c64ca5fad978bff025ff33a (patch) | |
tree | 72df75f86fa392a4533687fbe62d43b995892aa1 /src | |
parent | 4755130ed0013f32d884d005ef6ec379c94ef25e (diff) | |
download | binaryen-42896405d13d795c6c64ca5fad978bff025ff33a.tar.gz binaryen-42896405d13d795c6c64ca5fad978bff025ff33a.tar.bz2 binaryen-42896405d13d795c6c64ca5fad978bff025ff33a.zip |
Refactor Multi-Memory Lowering pass to support additional instructions (#5352)
This PR breaks up the two main functions involved in each memory instruction (getPtr, makeBoundsCheck) into several smaller functions. This is a first step in adding support for bounds checks in the instructions memory: init, copy, and fill. Each of these instructions is a more unique case than the other memory instructions that have already been added to the Multi-Memory Lowering pass.
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/MultiMemoryLowering.cpp | 108 |
1 files changed, 64 insertions, 44 deletions
diff --git a/src/passes/MultiMemoryLowering.cpp b/src/passes/MultiMemoryLowering.cpp index 0b066c39b..1183a7d74 100644 --- a/src/passes/MultiMemoryLowering.cpp +++ b/src/passes/MultiMemoryLowering.cpp @@ -115,25 +115,67 @@ struct MultiMemoryLowering : public Pass { replaceCurrent(builder.makeCall(funcName, {}, curr->type)); } - template<typename T> - Expression* getPtr(T* curr, Function* func, Index bytes) { - auto memoryIdx = parent.memoryIdxMap.at(curr->memory); + Expression* addOffsetGlobal(Expression* toExpr, Name memory) { + auto memoryIdx = parent.memoryIdxMap.at(memory); auto offsetGlobal = parent.getOffsetGlobal(memoryIdx); - Expression* ptrValue; + Expression* returnExpr; if (offsetGlobal) { - ptrValue = builder.makeBinary( + returnExpr = builder.makeBinary( Abstract::getBinary(parent.pointerType, Abstract::Add), builder.makeGlobalGet(offsetGlobal, parent.pointerType), - curr->ptr); + toExpr); } else { - ptrValue = curr->ptr; + returnExpr = toExpr; } + return returnExpr; + } + + Expression* makeAddGtuTrap(Expression* leftOperand, + Expression* rightOperand, + Expression* limit) { + Expression* gtuTrap = builder.makeIf( + builder.makeBinary( + Abstract::getBinary(parent.pointerType, Abstract::GtU), + builder.makeBinary( + Abstract::getBinary(parent.pointerType, Abstract::Add), + leftOperand, + rightOperand), + limit), + builder.makeUnreachable()); + return gtuTrap; + } + + Expression* makeAddGtuMemoryTrap(Expression* leftOperand, + Expression* rightOperand, + Name memory) { + auto memoryIdx = parent.memoryIdxMap.at(memory); + Name memorySizeFunc = parent.memorySizeNames[memoryIdx]; + Expression* gtuMemoryTrap = makeAddGtuTrap( + leftOperand, + rightOperand, + builder.makeCall(memorySizeFunc, {}, parent.pointerType)); + return gtuMemoryTrap; + } + + template<typename T> + Expression* makePtrBoundsCheck(T* curr, Index ptrIdx, Index bytes) { + Expression* boundsCheck = makeAddGtuMemoryTrap( + builder.makeBinary( + // ptr + offset (ea from wasm spec) + bit width + Abstract::getBinary(parent.pointerType, Abstract::Add), + builder.makeLocalGet(ptrIdx, parent.pointerType), + builder.makeConstPtr(curr->offset, parent.pointerType)), + builder.makeConstPtr(bytes, parent.pointerType), + curr->memory); + return boundsCheck; + } + template<typename T> Expression* getPtr(T* curr, Index bytes) { + Expression* ptrValue = addOffsetGlobal(curr->ptr, curr->memory); if (parent.checkBounds) { Index ptrIdx = Builder::addVar(getFunction(), parent.pointerType); Expression* ptrSet = builder.makeLocalSet(ptrIdx, ptrValue); - Expression* boundsCheck = - makeBoundsCheck(curr, ptrIdx, memoryIdx, bytes); + Expression* boundsCheck = makePtrBoundsCheck(curr, ptrIdx, bytes); Expression* ptrGet = builder.makeLocalGet(ptrIdx, parent.pointerType); return builder.makeBlock({ptrSet, boundsCheck, ptrGet}); } @@ -141,74 +183,52 @@ struct MultiMemoryLowering : public Pass { return ptrValue; } - template<typename T> - Expression* - makeBoundsCheck(T* curr, Index ptrIdx, Index memoryIdx, Index bytes) { - Name memorySizeFunc = parent.memorySizeNames[memoryIdx]; - Expression* boundsCheck = builder.makeIf( - builder.makeBinary( - Abstract::getBinary(parent.pointerType, Abstract::GtU), - builder.makeBinary( - // ptr + offset (ea from wasm spec) + bit width - // two builder Adds, we'll add the first two operands in the first - // add and then add the third operand in the second add - Abstract::getBinary(parent.pointerType, Abstract::Add), - builder.makeBinary( - Abstract::getBinary(parent.pointerType, Abstract::Add), - builder.makeLocalGet(ptrIdx, parent.pointerType), - builder.makeConstPtr(curr->offset, parent.pointerType)), - builder.makeConstPtr(bytes, parent.pointerType)), - builder.makeCall(memorySizeFunc, {}, parent.pointerType)), - builder.makeUnreachable()); - return boundsCheck; - } - template<typename T> void setMemory(T* curr) { curr->memory = parent.combinedMemory; } void visitLoad(Load* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->bytes); + curr->ptr = getPtr(curr, curr->bytes); setMemory(curr); } void visitStore(Store* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->bytes); + curr->ptr = getPtr(curr, curr->bytes); setMemory(curr); } void visitSIMDLoad(SIMDLoad* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes()); + curr->ptr = getPtr(curr, curr->getMemBytes()); setMemory(curr); } void visitSIMDLoadSplat(SIMDLoad* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes()); + curr->ptr = getPtr(curr, curr->getMemBytes()); setMemory(curr); } void visitSIMDLoadExtend(SIMDLoad* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes()); + curr->ptr = getPtr(curr, curr->getMemBytes()); setMemory(curr); } void visitSIMDLoadZero(SIMDLoad* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes()); + curr->ptr = getPtr(curr, curr->getMemBytes()); setMemory(curr); } void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->getMemBytes()); + curr->ptr = getPtr(curr, curr->getMemBytes()); setMemory(curr); } void visitAtomicRMW(AtomicRMW* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->bytes); + curr->ptr = getPtr(curr, curr->bytes); setMemory(curr); } void visitAtomicCmpxchg(AtomicCmpxchg* curr) { - curr->ptr = getPtr(curr, getFunction(), curr->bytes); + curr->ptr = getPtr(curr, curr->bytes); setMemory(curr); } @@ -226,12 +246,12 @@ struct MultiMemoryLowering : public Pass { default: WASM_UNREACHABLE("unexpected type"); } - curr->ptr = getPtr(curr, getFunction(), bytes); + curr->ptr = getPtr(curr, bytes); setMemory(curr); } void visitAtomicNotify(AtomicNotify* curr) { - curr->ptr = getPtr(curr, getFunction(), Index(4)); + curr->ptr = getPtr(curr, Index(4)); setMemory(curr); } }; @@ -247,7 +267,7 @@ struct MultiMemoryLowering : public Pass { this->wasm = module; prepCombinedMemory(); - addOffsetGlobals(); + makeOffsetGlobals(); adjustActiveDataSegmentOffsets(); createMemorySizeFunctions(); createMemoryGrowFunctions(); @@ -310,7 +330,7 @@ struct MultiMemoryLowering : public Pass { combinedMemory = Names::getValidMemoryName(*wasm, "combined_memory"); } - void addOffsetGlobals() { + void makeOffsetGlobals() { auto addGlobal = [&](Name name, size_t offset) { auto global = Builder::makeGlobal( name, |