diff options
Diffstat (limited to 'src/passes/SafeHeap.cpp')
-rw-r--r-- | src/passes/SafeHeap.cpp | 56 |
1 files changed, 36 insertions, 20 deletions
diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp index a3416afaf..a90696f2c 100644 --- a/src/passes/SafeHeap.cpp +++ b/src/passes/SafeHeap.cpp @@ -63,30 +63,21 @@ static Name getStoreName(Store* curr) { } struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> { - // If the getSbrkPtr function is implemented in the wasm, we must not - // instrument that, as it would lead to infinite recursion of it calling - // SAFE_HEAP_LOAD that calls it and so forth. - // As well as the getSbrkPtr function we also avoid instrumenting the - // module start function. This is because this function is used in - // shared memory builds to load the passive memory segments, which in - // turn means that value of sbrk() is not available. - Name getSbrkPtr; + // A set of function that we should ignore (not instrument). + std::set<Name> ignoreFunctions; bool isFunctionParallel() override { return true; } AccessInstrumenter* create() override { - return new AccessInstrumenter(getSbrkPtr); + return new AccessInstrumenter(ignoreFunctions); } - AccessInstrumenter(Name getSbrkPtr) : getSbrkPtr(getSbrkPtr) {} + AccessInstrumenter(std::set<Name> ignoreFunctions) + : ignoreFunctions(ignoreFunctions) {} void visitLoad(Load* curr) { - // As well as the getSbrkPtr function we also avoid insturmenting the - // module start function. This is because this function is used in - // shared memory builds to load the passive memory segments, which in - // turn means that value of sbrk() is not available. - if (getFunction()->name == getModule()->start || - getFunction()->name == getSbrkPtr || curr->type == Type::unreachable) { + if (ignoreFunctions.count(getFunction()->name) != 0 || + curr->type == Type::unreachable) { return; } Builder builder(*getModule()); @@ -97,8 +88,8 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> { } void visitStore(Store* curr) { - if (getFunction()->name == getModule()->start || - getFunction()->name == getSbrkPtr || curr->type == Type::unreachable) { + if (ignoreFunctions.count(getFunction()->name) != 0 || + curr->type == Type::unreachable) { return; } Builder builder(*getModule()); @@ -109,6 +100,12 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> { } }; +struct FindDirectCallees : public WalkerPass<PostWalker<FindDirectCallees>> { +public: + void visitCall(Call* curr) { callees.insert(curr->target); } + std::set<Name> callees; +}; + struct SafeHeap : public Pass { PassOptions options; @@ -117,12 +114,31 @@ struct SafeHeap : public Pass { // add imports addImports(module); // instrument loads and stores - AccessInstrumenter(getSbrkPtr).run(runner, module); + // We avoid instrumenting the module start function of any function + // that it directly calls. This is because in some cases the linker + // generates `__wasm_init_memory` (either as the start function or + // a function directly called from it) and this function is used in shared + // memory builds to load the passive memory segments, which in turn means + // that value of sbrk() is not available until after it has run. + std::set<Name> ignoreFunctions; + if (module->start.is()) { + // Note that this only finds directly called functions, not transitively + // called ones. That is enough given the current LLVM output as start + // will only contain very specific, linker-generated code + // (__wasm_init_memory etc. as mentioned above). + FindDirectCallees findDirectCallees; + findDirectCallees.walkFunctionInModule(module->getFunction(module->start), + module); + ignoreFunctions = findDirectCallees.callees; + ignoreFunctions.insert(module->start); + } + ignoreFunctions.insert(getSbrkPtr); + AccessInstrumenter(ignoreFunctions).run(runner, module); // add helper checking funcs and imports addGlobals(module, module->features); } - Name dynamicTopPtr, getSbrkPtr, sbrk, segfault, alignfault; + Name getSbrkPtr, dynamicTopPtr, sbrk, segfault, alignfault; void addImports(Module* module) { ImportInfo info(*module); |