summaryrefslogtreecommitdiff
path: root/src/passes/SafeHeap.cpp
diff options
context:
space:
mode:
authorAshley Nelson <nashley@google.com>2022-08-17 18:44:29 -0700
committerGitHub <noreply@github.com>2022-08-17 18:44:29 -0700
commit3aff4c6e85623c970280219c6699a66bc9de5f9b (patch)
treee5440bc966e523a7404ae2cec3458dacbe1281d1 /src/passes/SafeHeap.cpp
parentb70fe755aa4c90727edfd91dc0a9a51febf0239d (diff)
downloadbinaryen-3aff4c6e85623c970280219c6699a66bc9de5f9b.tar.gz
binaryen-3aff4c6e85623c970280219c6699a66bc9de5f9b.tar.bz2
binaryen-3aff4c6e85623c970280219c6699a66bc9de5f9b.zip
Mutli-Memories Support in IR (#4811)
This PR removes the single memory restriction in IR, adding support for a single module to reference multiple memories. To support this change, a new memory name field was added to 13 memory instructions in order to identify the memory for the instruction. It is a goal of this PR to maintain backwards compatibility with existing text and binary wasm modules, so memory indexes remain optional for memory instructions. Similarly, the JS API makes assumptions about which memory is intended when only one memory is present in the module. Another goal of this PR is that existing tests behavior be unaffected. That said, tests must now explicitly define a memory before invoking memory instructions or exporting a memory, and memory names are now printed for each memory instruction in the text format. There remain quite a few places where a hardcoded reference to the first memory persist (memory flattening, for example, will return early if more than one memory is present in the module). Many of these call-sites, particularly within passes, will require us to rethink how the optimization works in a multi-memories world. Other call-sites may necessitate more invasive code restructuring to fully convert away from relying on a globally available, single memory pointer.
Diffstat (limited to 'src/passes/SafeHeap.cpp')
-rw-r--r--src/passes/SafeHeap.cpp110
1 files changed, 73 insertions, 37 deletions
diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp
index 068c8ef73..eccb521d2 100644
--- a/src/passes/SafeHeap.cpp
+++ b/src/passes/SafeHeap.cpp
@@ -82,10 +82,11 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
return;
}
Builder builder(*getModule());
- replaceCurrent(
- builder.makeCall(getLoadName(curr),
- {curr->ptr, builder.makeConstPtr(curr->offset.addr)},
- curr->type));
+ auto memory = getModule()->getMemory(curr->memory);
+ replaceCurrent(builder.makeCall(
+ getLoadName(curr),
+ {curr->ptr, builder.makeConstPtr(curr->offset.addr, memory->indexType)},
+ curr->type));
}
void visitStore(Store* curr) {
@@ -94,9 +95,12 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
return;
}
Builder builder(*getModule());
+ auto memory = getModule()->getMemory(curr->memory);
replaceCurrent(builder.makeCall(
getStoreName(curr),
- {curr->ptr, builder.makeConstPtr(curr->offset.addr), curr->value},
+ {curr->ptr,
+ builder.makeConstPtr(curr->offset.addr, memory->indexType),
+ curr->value},
Type::none));
}
};
@@ -131,6 +135,7 @@ struct SafeHeap : public Pass {
void run(PassRunner* runner, Module* module) override {
options = runner->options;
+ assert(!module->memories.empty());
// add imports
addImports(module);
// instrument loads and stores
@@ -151,7 +156,7 @@ struct SafeHeap : public Pass {
void addImports(Module* module) {
ImportInfo info(*module);
- auto indexType = module->memory.indexType;
+ auto indexType = module->memories[0]->indexType;
if (auto* existing = info.getImportedFunction(ENV, GET_SBRK_PTR)) {
getSbrkPtr = existing->name;
} else if (auto* existing = module->getExportOrNull(GET_SBRK_PTR)) {
@@ -202,6 +207,7 @@ struct SafeHeap : public Pass {
continue;
}
load.type = type;
+ load.memory = module->memories[0]->name;
for (Index bytes : {1, 2, 4, 8, 16}) {
load.bytes = bytes;
if (bytes > type.getByteSize() || (type == Type::f32 && bytes != 4) ||
@@ -221,8 +227,9 @@ struct SafeHeap : public Pass {
}
for (auto isAtomic : {true, false}) {
load.isAtomic = isAtomic;
- if (isAtomic && !isPossibleAtomicOperation(
- align, bytes, module->memory.shared, type)) {
+ if (isAtomic &&
+ !isPossibleAtomicOperation(
+ align, bytes, module->memories[0]->shared, type)) {
continue;
}
addLoadFunc(load, module);
@@ -240,6 +247,7 @@ struct SafeHeap : public Pass {
}
store.valueType = valueType;
store.type = Type::none;
+ store.memory = module->memories[0]->name;
for (Index bytes : {1, 2, 4, 8, 16}) {
store.bytes = bytes;
if (bytes > valueType.getByteSize() ||
@@ -255,8 +263,9 @@ struct SafeHeap : public Pass {
}
for (auto isAtomic : {true, false}) {
store.isAtomic = isAtomic;
- if (isAtomic && !isPossibleAtomicOperation(
- align, bytes, module->memory.shared, valueType)) {
+ if (isAtomic &&
+ !isPossibleAtomicOperation(
+ align, bytes, module->memories[0]->shared, valueType)) {
continue;
}
addStoreFunc(store, module);
@@ -273,22 +282,30 @@ struct SafeHeap : public Pass {
return;
}
// pointer, offset
- auto indexType = module->memory.indexType;
+ auto memory = module->getMemory(style.memory);
+ auto indexType = memory->indexType;
auto funcSig = Signature({indexType, indexType}, style.type);
auto func = Builder::makeFunction(name, funcSig, {indexType});
Builder builder(*module);
auto* block = builder.makeBlock();
block->list.push_back(builder.makeLocalSet(
2,
- builder.makeBinary(module->memory.is64() ? AddInt64 : AddInt32,
+ builder.makeBinary(memory->is64() ? AddInt64 : AddInt32,
builder.makeLocalGet(0, indexType),
builder.makeLocalGet(1, indexType))));
// check for reading past valid memory: if pointer + offset + bytes
- block->list.push_back(
- makeBoundsCheck(style.type, builder, 2, style.bytes, module));
+ block->list.push_back(makeBoundsCheck(style.type,
+ builder,
+ 2,
+ style.bytes,
+ module,
+ memory->indexType,
+ memory->is64(),
+ memory->name));
// check proper alignment
if (style.align > 1) {
- block->list.push_back(makeAlignCheck(style.align, builder, 2, module));
+ block->list.push_back(
+ makeAlignCheck(style.align, builder, 2, module, memory->name));
}
// do the load
auto* load = module->allocator.alloc<Load>();
@@ -312,7 +329,9 @@ struct SafeHeap : public Pass {
if (module->getFunctionOrNull(name)) {
return;
}
- auto indexType = module->memory.indexType;
+ auto memory = module->getMemory(style.memory);
+ auto indexType = memory->indexType;
+ bool is64 = memory->is64();
// pointer, offset, value
auto funcSig =
Signature({indexType, indexType, style.valueType}, Type::none);
@@ -321,19 +340,27 @@ struct SafeHeap : public Pass {
auto* block = builder.makeBlock();
block->list.push_back(builder.makeLocalSet(
3,
- builder.makeBinary(module->memory.is64() ? AddInt64 : AddInt32,
+ builder.makeBinary(is64 ? AddInt64 : AddInt32,
builder.makeLocalGet(0, indexType),
builder.makeLocalGet(1, indexType))));
// check for reading past valid memory: if pointer + offset + bytes
- block->list.push_back(
- makeBoundsCheck(style.valueType, builder, 3, style.bytes, module));
+ block->list.push_back(makeBoundsCheck(style.valueType,
+ builder,
+ 3,
+ style.bytes,
+ module,
+ indexType,
+ is64,
+ memory->name));
// check proper alignment
if (style.align > 1) {
- block->list.push_back(makeAlignCheck(style.align, builder, 3, module));
+ block->list.push_back(
+ makeAlignCheck(style.align, builder, 3, module, memory->name));
}
// do the store
auto* store = module->allocator.alloc<Store>();
*store = style; // basically the same as the template we are given!
+ store->memory = memory->name;
store->ptr = builder.makeLocalGet(3, indexType);
store->value = builder.makeLocalGet(2, style.valueType);
block->list.push_back(store);
@@ -342,11 +369,15 @@ struct SafeHeap : public Pass {
module->addFunction(std::move(func));
}
- Expression*
- makeAlignCheck(Address align, Builder& builder, Index local, Module* module) {
- auto indexType = module->memory.indexType;
+ Expression* makeAlignCheck(Address align,
+ Builder& builder,
+ Index local,
+ Module* module,
+ Name memoryName) {
+ auto memory = module->getMemory(memoryName);
+ auto indexType = memory->indexType;
Expression* ptrBits = builder.makeLocalGet(local, indexType);
- if (module->memory.is64()) {
+ if (memory->is64()) {
ptrBits = builder.makeUnary(WrapInt64, ptrBits);
}
return builder.makeIf(
@@ -355,17 +386,21 @@ struct SafeHeap : public Pass {
builder.makeCall(alignfault, {}, Type::none));
}
- Expression* makeBoundsCheck(
- Type type, Builder& builder, Index local, Index bytes, Module* module) {
- auto indexType = module->memory.indexType;
- auto upperOp = module->memory.is64()
- ? options.lowMemoryUnused ? LtUInt64 : EqInt64
- : options.lowMemoryUnused ? LtUInt32 : EqInt32;
+ Expression* makeBoundsCheck(Type type,
+ Builder& builder,
+ Index local,
+ Index bytes,
+ Module* module,
+ Type indexType,
+ bool is64,
+ Name memory) {
+ auto upperOp = is64 ? options.lowMemoryUnused ? LtUInt64 : EqInt64
+ : options.lowMemoryUnused ? LtUInt32 : EqInt32;
auto upperBound = options.lowMemoryUnused ? PassOptions::LowMemoryBound : 0;
Expression* brkLocation;
if (sbrk.is()) {
brkLocation =
- builder.makeCall(sbrk, {builder.makeConstPtr(0)}, indexType);
+ builder.makeCall(sbrk, {builder.makeConstPtr(0, indexType)}, indexType);
} else {
Expression* sbrkPtr;
if (dynamicTopPtr.is()) {
@@ -373,22 +408,23 @@ struct SafeHeap : public Pass {
} else {
sbrkPtr = builder.makeCall(getSbrkPtr, {}, indexType);
}
- auto size = module->memory.is64() ? 8 : 4;
- brkLocation = builder.makeLoad(size, false, 0, size, sbrkPtr, indexType);
+ auto size = is64 ? 8 : 4;
+ brkLocation =
+ builder.makeLoad(size, false, 0, size, sbrkPtr, indexType, memory);
}
- auto gtuOp = module->memory.is64() ? GtUInt64 : GtUInt32;
- auto addOp = module->memory.is64() ? AddInt64 : AddInt32;
+ auto gtuOp = is64 ? GtUInt64 : GtUInt32;
+ auto addOp = is64 ? AddInt64 : AddInt32;
return builder.makeIf(
builder.makeBinary(
OrInt32,
builder.makeBinary(upperOp,
builder.makeLocalGet(local, indexType),
- builder.makeConstPtr(upperBound)),
+ builder.makeConstPtr(upperBound, indexType)),
builder.makeBinary(
gtuOp,
builder.makeBinary(addOp,
builder.makeLocalGet(local, indexType),
- builder.makeConstPtr(bytes)),
+ builder.makeConstPtr(bytes, indexType)),
brkLocation)),
builder.makeCall(segfault, {}, Type::none));
}