diff options
Diffstat (limited to 'src/wasm/wasm-emscripten.cpp')
-rw-r--r-- | src/wasm/wasm-emscripten.cpp | 121 |
1 files changed, 118 insertions, 3 deletions
diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp index 3ee3e4424..c21016baa 100644 --- a/src/wasm/wasm-emscripten.cpp +++ b/src/wasm/wasm-emscripten.cpp @@ -37,8 +37,11 @@ static Name STACK_SAVE("stackSave"); static Name STACK_RESTORE("stackRestore"); static Name STACK_ALLOC("stackAlloc"); static Name STACK_INIT("stack$init"); +static Name STACK_LIMIT("__stack_limit"); +static Name SET_STACK_LIMIT("__set_stack_limit"); static Name POST_INSTANTIATE("__post_instantiate"); static Name ASSIGN_GOT_ENTIRES("__assign_got_enties"); +static Name STACK_OVERFLOW_IMPORT("__handle_stack_overflow"); void addExportedFunction(Module& wasm, Function* function) { wasm.addFunction(function); @@ -92,8 +95,32 @@ Expression* EmscriptenGlueGenerator::generateLoadStackPointer() { return builder.makeGlobalGet(stackPointer->name, i32); } +inline Expression* stackBoundsCheck(Builder& builder, + Function* func, + Expression* value, + Global* stackPointer, + Global* stackLimit, + Name handler) { + // Add a local to store the value of the expression. We need the value twice: + // once to check if it has overflowed, and again to assign to store it. + auto newSP = Builder::addVar(func, stackPointer->type); + // (if (i32.lt_u (local.tee $newSP (...value...)) (global.get $__stack_limit)) + // (call $handler)) + auto check = + builder.makeIf(builder.makeBinary( + BinaryOp::LtUInt32, + builder.makeLocalTee(newSP, value), + builder.makeGlobalGet(stackLimit->name, stackLimit->type)), + builder.makeCall(handler, {}, none)); + // (global.set $__stack_pointer (local.get $newSP)) + auto newSet = builder.makeGlobalSet( + stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type)); + return builder.blockify(check, newSet); +} + Expression* -EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) { +EmscriptenGlueGenerator::generateStoreStackPointer(Function* func, + Expression* value) { if (!useStackPointerGlobal) { return builder.makeStore( /* bytes =*/4, @@ -107,6 +134,14 @@ EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) { if (!stackPointer) { Fatal() << "stack pointer global not found"; } + if (auto* stackLimit = wasm.getGlobalOrNull(STACK_LIMIT)) { + return stackBoundsCheck(builder, + func, + value, + stackPointer, + stackLimit, + importStackOverflowHandler()); + } return builder.makeGlobalSet(stackPointer->name, value); } @@ -132,7 +167,7 @@ void EmscriptenGlueGenerator::generateStackAllocFunction() { Const* subConst = builder.makeConst(Literal(~bitMask)); Binary* maskedSub = builder.makeBinary(AndInt32, sub, subConst); LocalSet* teeStackLocal = builder.makeLocalTee(1, maskedSub); - Expression* storeStack = generateStoreStackPointer(teeStackLocal); + Expression* storeStack = generateStoreStackPointer(function, teeStackLocal); Block* block = builder.makeBlock(); block->list.push_back(storeStack); @@ -149,7 +184,7 @@ void EmscriptenGlueGenerator::generateStackRestoreFunction() { Function* function = builder.makeFunction(STACK_RESTORE, std::move(params), none, {}); LocalGet* getArg = builder.makeLocalGet(0, i32); - Expression* store = generateStoreStackPointer(getArg); + Expression* store = generateStoreStackPointer(function, getArg); function->body = store; @@ -444,6 +479,86 @@ void EmscriptenGlueGenerator::replaceStackPointerGlobal() { wasm.removeGlobal(stackPointer->name); } +struct StackLimitEnforcer : public WalkerPass<PostWalker<StackLimitEnforcer>> { + StackLimitEnforcer(Global* stackPointer, + Global* stackLimit, + Builder& builder, + Name handler) + : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder), + handler(handler) {} + + bool isFunctionParallel() override { return true; } + + Pass* create() override { + return new StackLimitEnforcer(stackPointer, stackLimit, builder, handler); + } + + void visitGlobalSet(GlobalSet* curr) { + if (getModule()->getGlobalOrNull(curr->name) == stackPointer) { + replaceCurrent(stackBoundsCheck(builder, + getFunction(), + curr->value, + stackPointer, + stackLimit, + handler)); + } + } + +private: + Global* stackPointer; + Global* stackLimit; + Builder& builder; + Name handler; +}; + +void EmscriptenGlueGenerator::enforceStackLimit() { + Global* stackPointer = getStackPointerGlobal(); + if (!stackPointer) { + return; + } + + auto* stackLimit = builder.makeGlobal(STACK_LIMIT, + stackPointer->type, + builder.makeConst(Literal(0)), + Builder::Mutable); + wasm.addGlobal(stackLimit); + + auto handler = importStackOverflowHandler(); + + StackLimitEnforcer walker(stackPointer, stackLimit, builder, handler); + PassRunner runner(&wasm); + walker.run(&runner, &wasm); + + generateSetStackLimitFunction(); +} + +void EmscriptenGlueGenerator::generateSetStackLimitFunction() { + Function* function = + builder.makeFunction(SET_STACK_LIMIT, std::vector<Type>({i32}), none, {}); + LocalGet* getArg = builder.makeLocalGet(0, i32); + Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg); + function->body = store; + addExportedFunction(wasm, function); +} + +Name EmscriptenGlueGenerator::importStackOverflowHandler() { + ImportInfo info(wasm); + + if (auto* existing = info.getImportedFunction(ENV, STACK_OVERFLOW_IMPORT)) { + return existing->name; + } else { + auto* import = new Function; + import->name = STACK_OVERFLOW_IMPORT; + import->module = ENV; + import->base = STACK_OVERFLOW_IMPORT; + auto* functionType = ensureFunctionType("v", &wasm); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); + wasm.addFunction(import); + return STACK_OVERFLOW_IMPORT; + } +} + const Address UNKNOWN_OFFSET(uint32_t(-1)); std::vector<Address> getSegmentOffsets(Module& wasm) { |