diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/MultiMemoryLowering.cpp | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/src/passes/MultiMemoryLowering.cpp b/src/passes/MultiMemoryLowering.cpp index 1183a7d74..38a5ded93 100644 --- a/src/passes/MultiMemoryLowering.cpp +++ b/src/passes/MultiMemoryLowering.cpp @@ -170,6 +170,17 @@ struct MultiMemoryLowering : public Pass { return boundsCheck; } + Expression* makeDataSegmentBoundsCheck(MemoryInit* curr, + Index sizeIdx, + Index offsetIdx) { + auto& segment = parent.wasm->dataSegments[curr->segment]; + Expression* addGtuTrap = makeAddGtuTrap( + builder.makeLocalGet(offsetIdx, parent.pointerType), + builder.makeLocalGet(sizeIdx, parent.pointerType), + builder.makeConstPtr(segment->data.size(), parent.pointerType)); + return addGtuTrap; + } + template<typename T> Expression* getPtr(T* curr, Index bytes) { Expression* ptrValue = addOffsetGlobal(curr->ptr, curr->memory); if (parent.checkBounds) { @@ -183,6 +194,107 @@ struct MultiMemoryLowering : public Pass { return ptrValue; } + template<typename T> + Expression* getDest(T* curr, + Name memory, + Index sizeIdx = Index(-1), + Expression* localSet = nullptr, + Expression* additionalCheck = nullptr) { + Expression* destValue = addOffsetGlobal(curr->dest, memory); + + if (parent.checkBounds) { + Expression* sizeSet = builder.makeLocalSet(sizeIdx, curr->size); + Index destIdx = Builder::addVar(getFunction(), parent.pointerType); + Expression* destSet = builder.makeLocalSet(destIdx, destValue); + Expression* boundsCheck = makeAddGtuMemoryTrap( + builder.makeLocalGet(destIdx, parent.pointerType), + builder.makeLocalGet(sizeIdx, parent.pointerType), + memory); + std::vector<Expression*> exprs = { + destSet, localSet, sizeSet, boundsCheck}; + if (additionalCheck) { + exprs.push_back(additionalCheck); + } + Expression* destGet = builder.makeLocalGet(destIdx, parent.pointerType); + exprs.push_back(destGet); + return builder.makeBlock(exprs); + } + + return destValue; + } + + Expression* getSource(MemoryCopy* curr, + Index sizeIdx = Index(-1), + Index sourceIdx = Index(-1)) { + Expression* sourceValue = + addOffsetGlobal(curr->source, curr->sourceMemory); + + if (parent.checkBounds) { + Expression* boundsCheck = makeAddGtuMemoryTrap( + builder.makeLocalGet(sourceIdx, parent.pointerType), + builder.makeLocalGet(sizeIdx, parent.pointerType), + curr->sourceMemory); + Expression* sourceGet = + builder.makeLocalGet(sourceIdx, parent.pointerType); + std::vector<Expression*> exprs = {boundsCheck, sourceGet}; + return builder.makeBlock(exprs); + } + + return sourceValue; + } + + void visitMemoryInit(MemoryInit* curr) { + if (parent.checkBounds) { + Index offsetIdx = Builder::addVar(getFunction(), parent.pointerType); + Index sizeIdx = Builder::addVar(getFunction(), parent.pointerType); + curr->dest = + getDest(curr, + curr->memory, + sizeIdx, + builder.makeLocalSet(offsetIdx, curr->offset), + makeDataSegmentBoundsCheck(curr, sizeIdx, offsetIdx)); + curr->offset = builder.makeLocalGet(offsetIdx, parent.pointerType); + curr->size = builder.makeLocalGet(sizeIdx, parent.pointerType); + } else { + curr->dest = getDest(curr, curr->memory); + } + setMemory(curr); + } + + void visitMemoryCopy(MemoryCopy* curr) { + if (parent.checkBounds) { + Index sourceIdx = Builder::addVar(getFunction(), parent.pointerType); + Index sizeIdx = Builder::addVar(getFunction(), parent.pointerType); + curr->dest = getDest(curr, + curr->destMemory, + sizeIdx, + builder.makeLocalSet(sourceIdx, curr->source)); + curr->source = getSource(curr, sizeIdx, sourceIdx); + curr->size = builder.makeLocalGet(sizeIdx, parent.pointerType); + } else { + curr->dest = getDest(curr, curr->destMemory); + curr->source = getSource(curr); + } + curr->destMemory = parent.combinedMemory; + curr->sourceMemory = parent.combinedMemory; + } + + void visitMemoryFill(MemoryFill* curr) { + if (parent.checkBounds) { + Index valueIdx = Builder::addVar(getFunction(), parent.pointerType); + Index sizeIdx = Builder::addVar(getFunction(), parent.pointerType); + curr->dest = getDest(curr, + curr->memory, + sizeIdx, + builder.makeLocalSet(valueIdx, curr->value)); + curr->value = builder.makeLocalGet(valueIdx, parent.pointerType); + curr->size = builder.makeLocalGet(sizeIdx, parent.pointerType); + } else { + curr->dest = getDest(curr, curr->memory); + } + setMemory(curr); + } + template<typename T> void setMemory(T* curr) { curr->memory = parent.combinedMemory; } |