summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAshley Nelson <nashley@google.com>2022-12-15 13:33:14 -0800
committerGitHub <noreply@github.com>2022-12-15 13:33:14 -0800
commit7769196090e7ce2c150cd8a58fad0c89430d3d2b (patch)
tree141965a7ff40c96b5c9100283a9575995275ded1 /src
parent42896405d13d795c6c64ca5fad978bff025ff33a (diff)
downloadbinaryen-7769196090e7ce2c150cd8a58fad0c89430d3d2b.tar.gz
binaryen-7769196090e7ce2c150cd8a58fad0c89430d3d2b.tar.bz2
binaryen-7769196090e7ce2c150cd8a58fad0c89430d3d2b.zip
Add memory: init, copy, fill support to Multi-Memory Lowering Pass (#5346)
This PR adds support for memory.init, memory.copy, and memory.fill instructions in the multi-memory lowering pass. Also includes optional bounds checks per the wasm spec guidelines.
Diffstat (limited to 'src')
-rw-r--r--src/passes/MultiMemoryLowering.cpp112
1 files changed, 112 insertions, 0 deletions
diff --git a/src/passes/MultiMemoryLowering.cpp b/src/passes/MultiMemoryLowering.cpp
index 1183a7d74..38a5ded93 100644
--- a/src/passes/MultiMemoryLowering.cpp
+++ b/src/passes/MultiMemoryLowering.cpp
@@ -170,6 +170,17 @@ struct MultiMemoryLowering : public Pass {
return boundsCheck;
}
+ Expression* makeDataSegmentBoundsCheck(MemoryInit* curr,
+ Index sizeIdx,
+ Index offsetIdx) {
+ auto& segment = parent.wasm->dataSegments[curr->segment];
+ Expression* addGtuTrap = makeAddGtuTrap(
+ builder.makeLocalGet(offsetIdx, parent.pointerType),
+ builder.makeLocalGet(sizeIdx, parent.pointerType),
+ builder.makeConstPtr(segment->data.size(), parent.pointerType));
+ return addGtuTrap;
+ }
+
template<typename T> Expression* getPtr(T* curr, Index bytes) {
Expression* ptrValue = addOffsetGlobal(curr->ptr, curr->memory);
if (parent.checkBounds) {
@@ -183,6 +194,107 @@ struct MultiMemoryLowering : public Pass {
return ptrValue;
}
+ template<typename T>
+ Expression* getDest(T* curr,
+ Name memory,
+ Index sizeIdx = Index(-1),
+ Expression* localSet = nullptr,
+ Expression* additionalCheck = nullptr) {
+ Expression* destValue = addOffsetGlobal(curr->dest, memory);
+
+ if (parent.checkBounds) {
+ Expression* sizeSet = builder.makeLocalSet(sizeIdx, curr->size);
+ Index destIdx = Builder::addVar(getFunction(), parent.pointerType);
+ Expression* destSet = builder.makeLocalSet(destIdx, destValue);
+ Expression* boundsCheck = makeAddGtuMemoryTrap(
+ builder.makeLocalGet(destIdx, parent.pointerType),
+ builder.makeLocalGet(sizeIdx, parent.pointerType),
+ memory);
+ std::vector<Expression*> exprs = {
+ destSet, localSet, sizeSet, boundsCheck};
+ if (additionalCheck) {
+ exprs.push_back(additionalCheck);
+ }
+ Expression* destGet = builder.makeLocalGet(destIdx, parent.pointerType);
+ exprs.push_back(destGet);
+ return builder.makeBlock(exprs);
+ }
+
+ return destValue;
+ }
+
+ Expression* getSource(MemoryCopy* curr,
+ Index sizeIdx = Index(-1),
+ Index sourceIdx = Index(-1)) {
+ Expression* sourceValue =
+ addOffsetGlobal(curr->source, curr->sourceMemory);
+
+ if (parent.checkBounds) {
+ Expression* boundsCheck = makeAddGtuMemoryTrap(
+ builder.makeLocalGet(sourceIdx, parent.pointerType),
+ builder.makeLocalGet(sizeIdx, parent.pointerType),
+ curr->sourceMemory);
+ Expression* sourceGet =
+ builder.makeLocalGet(sourceIdx, parent.pointerType);
+ std::vector<Expression*> exprs = {boundsCheck, sourceGet};
+ return builder.makeBlock(exprs);
+ }
+
+ return sourceValue;
+ }
+
+ void visitMemoryInit(MemoryInit* curr) {
+ if (parent.checkBounds) {
+ Index offsetIdx = Builder::addVar(getFunction(), parent.pointerType);
+ Index sizeIdx = Builder::addVar(getFunction(), parent.pointerType);
+ curr->dest =
+ getDest(curr,
+ curr->memory,
+ sizeIdx,
+ builder.makeLocalSet(offsetIdx, curr->offset),
+ makeDataSegmentBoundsCheck(curr, sizeIdx, offsetIdx));
+ curr->offset = builder.makeLocalGet(offsetIdx, parent.pointerType);
+ curr->size = builder.makeLocalGet(sizeIdx, parent.pointerType);
+ } else {
+ curr->dest = getDest(curr, curr->memory);
+ }
+ setMemory(curr);
+ }
+
+ void visitMemoryCopy(MemoryCopy* curr) {
+ if (parent.checkBounds) {
+ Index sourceIdx = Builder::addVar(getFunction(), parent.pointerType);
+ Index sizeIdx = Builder::addVar(getFunction(), parent.pointerType);
+ curr->dest = getDest(curr,
+ curr->destMemory,
+ sizeIdx,
+ builder.makeLocalSet(sourceIdx, curr->source));
+ curr->source = getSource(curr, sizeIdx, sourceIdx);
+ curr->size = builder.makeLocalGet(sizeIdx, parent.pointerType);
+ } else {
+ curr->dest = getDest(curr, curr->destMemory);
+ curr->source = getSource(curr);
+ }
+ curr->destMemory = parent.combinedMemory;
+ curr->sourceMemory = parent.combinedMemory;
+ }
+
+ void visitMemoryFill(MemoryFill* curr) {
+ if (parent.checkBounds) {
+ Index valueIdx = Builder::addVar(getFunction(), parent.pointerType);
+ Index sizeIdx = Builder::addVar(getFunction(), parent.pointerType);
+ curr->dest = getDest(curr,
+ curr->memory,
+ sizeIdx,
+ builder.makeLocalSet(valueIdx, curr->value));
+ curr->value = builder.makeLocalGet(valueIdx, parent.pointerType);
+ curr->size = builder.makeLocalGet(sizeIdx, parent.pointerType);
+ } else {
+ curr->dest = getDest(curr, curr->memory);
+ }
+ setMemory(curr);
+ }
+
template<typename T> void setMemory(T* curr) {
curr->memory = parent.combinedMemory;
}