diff options
Diffstat (limited to 'src/passes/SafeHeap.cpp')
-rw-r--r-- | src/passes/SafeHeap.cpp | 110 |
1 files changed, 73 insertions, 37 deletions
diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp index 068c8ef73..eccb521d2 100644 --- a/src/passes/SafeHeap.cpp +++ b/src/passes/SafeHeap.cpp @@ -82,10 +82,11 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> { return; } Builder builder(*getModule()); - replaceCurrent( - builder.makeCall(getLoadName(curr), - {curr->ptr, builder.makeConstPtr(curr->offset.addr)}, - curr->type)); + auto memory = getModule()->getMemory(curr->memory); + replaceCurrent(builder.makeCall( + getLoadName(curr), + {curr->ptr, builder.makeConstPtr(curr->offset.addr, memory->indexType)}, + curr->type)); } void visitStore(Store* curr) { @@ -94,9 +95,12 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> { return; } Builder builder(*getModule()); + auto memory = getModule()->getMemory(curr->memory); replaceCurrent(builder.makeCall( getStoreName(curr), - {curr->ptr, builder.makeConstPtr(curr->offset.addr), curr->value}, + {curr->ptr, + builder.makeConstPtr(curr->offset.addr, memory->indexType), + curr->value}, Type::none)); } }; @@ -131,6 +135,7 @@ struct SafeHeap : public Pass { void run(PassRunner* runner, Module* module) override { options = runner->options; + assert(!module->memories.empty()); // add imports addImports(module); // instrument loads and stores @@ -151,7 +156,7 @@ struct SafeHeap : public Pass { void addImports(Module* module) { ImportInfo info(*module); - auto indexType = module->memory.indexType; + auto indexType = module->memories[0]->indexType; if (auto* existing = info.getImportedFunction(ENV, GET_SBRK_PTR)) { getSbrkPtr = existing->name; } else if (auto* existing = module->getExportOrNull(GET_SBRK_PTR)) { @@ -202,6 +207,7 @@ struct SafeHeap : public Pass { continue; } load.type = type; + load.memory = module->memories[0]->name; for (Index bytes : {1, 2, 4, 8, 16}) { load.bytes = bytes; if (bytes > type.getByteSize() || (type == Type::f32 && bytes != 4) || @@ -221,8 +227,9 @@ struct SafeHeap : public Pass { } for (auto isAtomic : {true, false}) { load.isAtomic = isAtomic; - if (isAtomic && !isPossibleAtomicOperation( - align, bytes, module->memory.shared, type)) { + if (isAtomic && + !isPossibleAtomicOperation( + align, bytes, module->memories[0]->shared, type)) { continue; } addLoadFunc(load, module); @@ -240,6 +247,7 @@ struct SafeHeap : public Pass { } store.valueType = valueType; store.type = Type::none; + store.memory = module->memories[0]->name; for (Index bytes : {1, 2, 4, 8, 16}) { store.bytes = bytes; if (bytes > valueType.getByteSize() || @@ -255,8 +263,9 @@ struct SafeHeap : public Pass { } for (auto isAtomic : {true, false}) { store.isAtomic = isAtomic; - if (isAtomic && !isPossibleAtomicOperation( - align, bytes, module->memory.shared, valueType)) { + if (isAtomic && + !isPossibleAtomicOperation( + align, bytes, module->memories[0]->shared, valueType)) { continue; } addStoreFunc(store, module); @@ -273,22 +282,30 @@ struct SafeHeap : public Pass { return; } // pointer, offset - auto indexType = module->memory.indexType; + auto memory = module->getMemory(style.memory); + auto indexType = memory->indexType; auto funcSig = Signature({indexType, indexType}, style.type); auto func = Builder::makeFunction(name, funcSig, {indexType}); Builder builder(*module); auto* block = builder.makeBlock(); block->list.push_back(builder.makeLocalSet( 2, - builder.makeBinary(module->memory.is64() ? AddInt64 : AddInt32, + builder.makeBinary(memory->is64() ? AddInt64 : AddInt32, builder.makeLocalGet(0, indexType), builder.makeLocalGet(1, indexType)))); // check for reading past valid memory: if pointer + offset + bytes - block->list.push_back( - makeBoundsCheck(style.type, builder, 2, style.bytes, module)); + block->list.push_back(makeBoundsCheck(style.type, + builder, + 2, + style.bytes, + module, + memory->indexType, + memory->is64(), + memory->name)); // check proper alignment if (style.align > 1) { - block->list.push_back(makeAlignCheck(style.align, builder, 2, module)); + block->list.push_back( + makeAlignCheck(style.align, builder, 2, module, memory->name)); } // do the load auto* load = module->allocator.alloc<Load>(); @@ -312,7 +329,9 @@ struct SafeHeap : public Pass { if (module->getFunctionOrNull(name)) { return; } - auto indexType = module->memory.indexType; + auto memory = module->getMemory(style.memory); + auto indexType = memory->indexType; + bool is64 = memory->is64(); // pointer, offset, value auto funcSig = Signature({indexType, indexType, style.valueType}, Type::none); @@ -321,19 +340,27 @@ struct SafeHeap : public Pass { auto* block = builder.makeBlock(); block->list.push_back(builder.makeLocalSet( 3, - builder.makeBinary(module->memory.is64() ? AddInt64 : AddInt32, + builder.makeBinary(is64 ? AddInt64 : AddInt32, builder.makeLocalGet(0, indexType), builder.makeLocalGet(1, indexType)))); // check for reading past valid memory: if pointer + offset + bytes - block->list.push_back( - makeBoundsCheck(style.valueType, builder, 3, style.bytes, module)); + block->list.push_back(makeBoundsCheck(style.valueType, + builder, + 3, + style.bytes, + module, + indexType, + is64, + memory->name)); // check proper alignment if (style.align > 1) { - block->list.push_back(makeAlignCheck(style.align, builder, 3, module)); + block->list.push_back( + makeAlignCheck(style.align, builder, 3, module, memory->name)); } // do the store auto* store = module->allocator.alloc<Store>(); *store = style; // basically the same as the template we are given! + store->memory = memory->name; store->ptr = builder.makeLocalGet(3, indexType); store->value = builder.makeLocalGet(2, style.valueType); block->list.push_back(store); @@ -342,11 +369,15 @@ struct SafeHeap : public Pass { module->addFunction(std::move(func)); } - Expression* - makeAlignCheck(Address align, Builder& builder, Index local, Module* module) { - auto indexType = module->memory.indexType; + Expression* makeAlignCheck(Address align, + Builder& builder, + Index local, + Module* module, + Name memoryName) { + auto memory = module->getMemory(memoryName); + auto indexType = memory->indexType; Expression* ptrBits = builder.makeLocalGet(local, indexType); - if (module->memory.is64()) { + if (memory->is64()) { ptrBits = builder.makeUnary(WrapInt64, ptrBits); } return builder.makeIf( @@ -355,17 +386,21 @@ struct SafeHeap : public Pass { builder.makeCall(alignfault, {}, Type::none)); } - Expression* makeBoundsCheck( - Type type, Builder& builder, Index local, Index bytes, Module* module) { - auto indexType = module->memory.indexType; - auto upperOp = module->memory.is64() - ? options.lowMemoryUnused ? LtUInt64 : EqInt64 - : options.lowMemoryUnused ? LtUInt32 : EqInt32; + Expression* makeBoundsCheck(Type type, + Builder& builder, + Index local, + Index bytes, + Module* module, + Type indexType, + bool is64, + Name memory) { + auto upperOp = is64 ? options.lowMemoryUnused ? LtUInt64 : EqInt64 + : options.lowMemoryUnused ? LtUInt32 : EqInt32; auto upperBound = options.lowMemoryUnused ? PassOptions::LowMemoryBound : 0; Expression* brkLocation; if (sbrk.is()) { brkLocation = - builder.makeCall(sbrk, {builder.makeConstPtr(0)}, indexType); + builder.makeCall(sbrk, {builder.makeConstPtr(0, indexType)}, indexType); } else { Expression* sbrkPtr; if (dynamicTopPtr.is()) { @@ -373,22 +408,23 @@ struct SafeHeap : public Pass { } else { sbrkPtr = builder.makeCall(getSbrkPtr, {}, indexType); } - auto size = module->memory.is64() ? 8 : 4; - brkLocation = builder.makeLoad(size, false, 0, size, sbrkPtr, indexType); + auto size = is64 ? 8 : 4; + brkLocation = + builder.makeLoad(size, false, 0, size, sbrkPtr, indexType, memory); } - auto gtuOp = module->memory.is64() ? GtUInt64 : GtUInt32; - auto addOp = module->memory.is64() ? AddInt64 : AddInt32; + auto gtuOp = is64 ? GtUInt64 : GtUInt32; + auto addOp = is64 ? AddInt64 : AddInt32; return builder.makeIf( builder.makeBinary( OrInt32, builder.makeBinary(upperOp, builder.makeLocalGet(local, indexType), - builder.makeConstPtr(upperBound)), + builder.makeConstPtr(upperBound, indexType)), builder.makeBinary( gtuOp, builder.makeBinary(addOp, builder.makeLocalGet(local, indexType), - builder.makeConstPtr(bytes)), + builder.makeConstPtr(bytes, indexType)), brkLocation)), builder.makeCall(segfault, {}, Type::none)); } |