summaryrefslogtreecommitdiff
path: root/src/passes/SafeHeap.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/SafeHeap.cpp')
-rw-r--r--src/passes/SafeHeap.cpp110
1 files changed, 73 insertions, 37 deletions
diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp
index 068c8ef73..eccb521d2 100644
--- a/src/passes/SafeHeap.cpp
+++ b/src/passes/SafeHeap.cpp
@@ -82,10 +82,11 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
return;
}
Builder builder(*getModule());
- replaceCurrent(
- builder.makeCall(getLoadName(curr),
- {curr->ptr, builder.makeConstPtr(curr->offset.addr)},
- curr->type));
+ auto memory = getModule()->getMemory(curr->memory);
+ replaceCurrent(builder.makeCall(
+ getLoadName(curr),
+ {curr->ptr, builder.makeConstPtr(curr->offset.addr, memory->indexType)},
+ curr->type));
}
void visitStore(Store* curr) {
@@ -94,9 +95,12 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
return;
}
Builder builder(*getModule());
+ auto memory = getModule()->getMemory(curr->memory);
replaceCurrent(builder.makeCall(
getStoreName(curr),
- {curr->ptr, builder.makeConstPtr(curr->offset.addr), curr->value},
+ {curr->ptr,
+ builder.makeConstPtr(curr->offset.addr, memory->indexType),
+ curr->value},
Type::none));
}
};
@@ -131,6 +135,7 @@ struct SafeHeap : public Pass {
void run(PassRunner* runner, Module* module) override {
options = runner->options;
+ assert(!module->memories.empty());
// add imports
addImports(module);
// instrument loads and stores
@@ -151,7 +156,7 @@ struct SafeHeap : public Pass {
void addImports(Module* module) {
ImportInfo info(*module);
- auto indexType = module->memory.indexType;
+ auto indexType = module->memories[0]->indexType;
if (auto* existing = info.getImportedFunction(ENV, GET_SBRK_PTR)) {
getSbrkPtr = existing->name;
} else if (auto* existing = module->getExportOrNull(GET_SBRK_PTR)) {
@@ -202,6 +207,7 @@ struct SafeHeap : public Pass {
continue;
}
load.type = type;
+ load.memory = module->memories[0]->name;
for (Index bytes : {1, 2, 4, 8, 16}) {
load.bytes = bytes;
if (bytes > type.getByteSize() || (type == Type::f32 && bytes != 4) ||
@@ -221,8 +227,9 @@ struct SafeHeap : public Pass {
}
for (auto isAtomic : {true, false}) {
load.isAtomic = isAtomic;
- if (isAtomic && !isPossibleAtomicOperation(
- align, bytes, module->memory.shared, type)) {
+ if (isAtomic &&
+ !isPossibleAtomicOperation(
+ align, bytes, module->memories[0]->shared, type)) {
continue;
}
addLoadFunc(load, module);
@@ -240,6 +247,7 @@ struct SafeHeap : public Pass {
}
store.valueType = valueType;
store.type = Type::none;
+ store.memory = module->memories[0]->name;
for (Index bytes : {1, 2, 4, 8, 16}) {
store.bytes = bytes;
if (bytes > valueType.getByteSize() ||
@@ -255,8 +263,9 @@ struct SafeHeap : public Pass {
}
for (auto isAtomic : {true, false}) {
store.isAtomic = isAtomic;
- if (isAtomic && !isPossibleAtomicOperation(
- align, bytes, module->memory.shared, valueType)) {
+ if (isAtomic &&
+ !isPossibleAtomicOperation(
+ align, bytes, module->memories[0]->shared, valueType)) {
continue;
}
addStoreFunc(store, module);
@@ -273,22 +282,30 @@ struct SafeHeap : public Pass {
return;
}
// pointer, offset
- auto indexType = module->memory.indexType;
+ auto memory = module->getMemory(style.memory);
+ auto indexType = memory->indexType;
auto funcSig = Signature({indexType, indexType}, style.type);
auto func = Builder::makeFunction(name, funcSig, {indexType});
Builder builder(*module);
auto* block = builder.makeBlock();
block->list.push_back(builder.makeLocalSet(
2,
- builder.makeBinary(module->memory.is64() ? AddInt64 : AddInt32,
+ builder.makeBinary(memory->is64() ? AddInt64 : AddInt32,
builder.makeLocalGet(0, indexType),
builder.makeLocalGet(1, indexType))));
// check for reading past valid memory: if pointer + offset + bytes
- block->list.push_back(
- makeBoundsCheck(style.type, builder, 2, style.bytes, module));
+ block->list.push_back(makeBoundsCheck(style.type,
+ builder,
+ 2,
+ style.bytes,
+ module,
+ memory->indexType,
+ memory->is64(),
+ memory->name));
// check proper alignment
if (style.align > 1) {
- block->list.push_back(makeAlignCheck(style.align, builder, 2, module));
+ block->list.push_back(
+ makeAlignCheck(style.align, builder, 2, module, memory->name));
}
// do the load
auto* load = module->allocator.alloc<Load>();
@@ -312,7 +329,9 @@ struct SafeHeap : public Pass {
if (module->getFunctionOrNull(name)) {
return;
}
- auto indexType = module->memory.indexType;
+ auto memory = module->getMemory(style.memory);
+ auto indexType = memory->indexType;
+ bool is64 = memory->is64();
// pointer, offset, value
auto funcSig =
Signature({indexType, indexType, style.valueType}, Type::none);
@@ -321,19 +340,27 @@ struct SafeHeap : public Pass {
auto* block = builder.makeBlock();
block->list.push_back(builder.makeLocalSet(
3,
- builder.makeBinary(module->memory.is64() ? AddInt64 : AddInt32,
+ builder.makeBinary(is64 ? AddInt64 : AddInt32,
builder.makeLocalGet(0, indexType),
builder.makeLocalGet(1, indexType))));
// check for reading past valid memory: if pointer + offset + bytes
- block->list.push_back(
- makeBoundsCheck(style.valueType, builder, 3, style.bytes, module));
+ block->list.push_back(makeBoundsCheck(style.valueType,
+ builder,
+ 3,
+ style.bytes,
+ module,
+ indexType,
+ is64,
+ memory->name));
// check proper alignment
if (style.align > 1) {
- block->list.push_back(makeAlignCheck(style.align, builder, 3, module));
+ block->list.push_back(
+ makeAlignCheck(style.align, builder, 3, module, memory->name));
}
// do the store
auto* store = module->allocator.alloc<Store>();
*store = style; // basically the same as the template we are given!
+ store->memory = memory->name;
store->ptr = builder.makeLocalGet(3, indexType);
store->value = builder.makeLocalGet(2, style.valueType);
block->list.push_back(store);
@@ -342,11 +369,15 @@ struct SafeHeap : public Pass {
module->addFunction(std::move(func));
}
- Expression*
- makeAlignCheck(Address align, Builder& builder, Index local, Module* module) {
- auto indexType = module->memory.indexType;
+ Expression* makeAlignCheck(Address align,
+ Builder& builder,
+ Index local,
+ Module* module,
+ Name memoryName) {
+ auto memory = module->getMemory(memoryName);
+ auto indexType = memory->indexType;
Expression* ptrBits = builder.makeLocalGet(local, indexType);
- if (module->memory.is64()) {
+ if (memory->is64()) {
ptrBits = builder.makeUnary(WrapInt64, ptrBits);
}
return builder.makeIf(
@@ -355,17 +386,21 @@ struct SafeHeap : public Pass {
builder.makeCall(alignfault, {}, Type::none));
}
- Expression* makeBoundsCheck(
- Type type, Builder& builder, Index local, Index bytes, Module* module) {
- auto indexType = module->memory.indexType;
- auto upperOp = module->memory.is64()
- ? options.lowMemoryUnused ? LtUInt64 : EqInt64
- : options.lowMemoryUnused ? LtUInt32 : EqInt32;
+ Expression* makeBoundsCheck(Type type,
+ Builder& builder,
+ Index local,
+ Index bytes,
+ Module* module,
+ Type indexType,
+ bool is64,
+ Name memory) {
+ auto upperOp = is64 ? options.lowMemoryUnused ? LtUInt64 : EqInt64
+ : options.lowMemoryUnused ? LtUInt32 : EqInt32;
auto upperBound = options.lowMemoryUnused ? PassOptions::LowMemoryBound : 0;
Expression* brkLocation;
if (sbrk.is()) {
brkLocation =
- builder.makeCall(sbrk, {builder.makeConstPtr(0)}, indexType);
+ builder.makeCall(sbrk, {builder.makeConstPtr(0, indexType)}, indexType);
} else {
Expression* sbrkPtr;
if (dynamicTopPtr.is()) {
@@ -373,22 +408,23 @@ struct SafeHeap : public Pass {
} else {
sbrkPtr = builder.makeCall(getSbrkPtr, {}, indexType);
}
- auto size = module->memory.is64() ? 8 : 4;
- brkLocation = builder.makeLoad(size, false, 0, size, sbrkPtr, indexType);
+ auto size = is64 ? 8 : 4;
+ brkLocation =
+ builder.makeLoad(size, false, 0, size, sbrkPtr, indexType, memory);
}
- auto gtuOp = module->memory.is64() ? GtUInt64 : GtUInt32;
- auto addOp = module->memory.is64() ? AddInt64 : AddInt32;
+ auto gtuOp = is64 ? GtUInt64 : GtUInt32;
+ auto addOp = is64 ? AddInt64 : AddInt32;
return builder.makeIf(
builder.makeBinary(
OrInt32,
builder.makeBinary(upperOp,
builder.makeLocalGet(local, indexType),
- builder.makeConstPtr(upperBound)),
+ builder.makeConstPtr(upperBound, indexType)),
builder.makeBinary(
gtuOp,
builder.makeBinary(addOp,
builder.makeLocalGet(local, indexType),
- builder.makeConstPtr(bytes)),
+ builder.makeConstPtr(bytes, indexType)),
brkLocation)),
builder.makeCall(segfault, {}, Type::none));
}