diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/StackCheck.cpp | 77 |
1 files changed, 53 insertions, 24 deletions
diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp index 4c797351b..183420fba 100644 --- a/src/passes/StackCheck.cpp +++ b/src/passes/StackCheck.cpp @@ -31,8 +31,16 @@ namespace wasm { +// The base is where the stack begins. As it goes down, that is the highest +// valid address. +static Name STACK_BASE("__stack_base"); +// The limit is the farthest it can grow to, which is the lowest valid address. static Name STACK_LIMIT("__stack_limit"); +// Old version, which sets the limit. +// TODO: remove this static Name SET_STACK_LIMIT("__set_stack_limit"); +// New version, which sets the base and the limit. +static Name SET_STACK_LIMITS("__set_stack_limits"); static void importStackOverflowHandler(Module& module, Name name) { ImportInfo info(module); @@ -55,34 +63,43 @@ static void addExportedFunction(Module& module, Function* function) { module.addExport(export_); } -static void generateSetStackLimitFunction(Module& module) { +static void generateSetStackLimitFunctions(Module& module) { Builder builder(module); - Function* function = + // One-parameter version + Function* limitFunc = builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {}); LocalGet* getArg = builder.makeLocalGet(0, Type::i32); Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg); - function->body = store; - addExportedFunction(module, function); + limitFunc->body = store; + addExportedFunction(module, limitFunc); + // Two-parameter version + Function* limitsFunc = builder.makeFunction( + SET_STACK_LIMITS, Signature({Type::i32, Type::i32}, Type::none), {}); + LocalGet* getBase = builder.makeLocalGet(0, Type::i32); + Expression* storeBase = builder.makeGlobalSet(STACK_BASE, getBase); + LocalGet* getLimit = builder.makeLocalGet(1, Type::i32); + Expression* storeLimit = builder.makeGlobalSet(STACK_LIMIT, getLimit); + limitsFunc->body = builder.makeBlock({storeBase, storeLimit}); + addExportedFunction(module, limitsFunc); } -struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> { - EnforceStackLimit(Global* stackPointer, - Global* stackLimit, - Builder& builder, - Name handler) - : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder), - handler(handler) {} +struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> { + EnforceStackLimits(Global* stackPointer, + Global* stackBase, + Global* stackLimit, + Builder& builder, + Name handler) + : stackPointer(stackPointer), stackBase(stackBase), stackLimit(stackLimit), + builder(builder), handler(handler) {} bool isFunctionParallel() override { return true; } Pass* create() override { - return new EnforceStackLimit(stackPointer, stackLimit, builder, handler); + return new EnforceStackLimits( + stackPointer, stackBase, stackLimit, builder, handler); } - Expression* stackBoundsCheck(Function* func, - Expression* value, - Global* stackPointer, - Global* stackLimit) { + Expression* stackBoundsCheck(Function* func, Expression* value) { // Add a local to store the value of the expression. We need the value // twice: once to check if it has overflowed, and again to assign to store // it. @@ -95,12 +112,18 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> { } else { handlerExpr = builder.makeUnreachable(); } - // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit)) + // If it is >= the base or <= the limit, then error. auto check = builder.makeIf( builder.makeBinary( - BinaryOp::LtUInt32, - builder.makeLocalTee(newSP, value, stackPointer->type), - builder.makeGlobalGet(stackLimit->name, stackLimit->type)), + BinaryOp::OrInt32, + builder.makeBinary( + BinaryOp::GtUInt32, + builder.makeLocalTee(newSP, value, stackPointer->type), + builder.makeGlobalGet(stackBase->name, stackBase->type)), + builder.makeBinary( + BinaryOp::LtUInt32, + builder.makeLocalGet(newSP, stackPointer->type), + builder.makeGlobalGet(stackLimit->name, stackLimit->type))), handlerExpr); // (global.set $__stack_pointer (local.get $newSP)) auto newSet = builder.makeGlobalSet( @@ -110,13 +133,13 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> { void visitGlobalSet(GlobalSet* curr) { if (getModule()->getGlobalOrNull(curr->name) == stackPointer) { - replaceCurrent( - stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit)); + replaceCurrent(stackBoundsCheck(getFunction(), curr->value)); } } private: Global* stackPointer; + Global* stackBase; Global* stackLimit; Builder& builder; Name handler; @@ -139,6 +162,12 @@ struct StackCheck : public Pass { } Builder builder(*module); + Global* stackBase = builder.makeGlobal(STACK_BASE, + stackPointer->type, + builder.makeConst(int32_t(0)), + Builder::Mutable); + module->addGlobal(stackBase); + Global* stackLimit = builder.makeGlobal(STACK_LIMIT, stackPointer->type, builder.makeConst(int32_t(0)), @@ -146,9 +175,9 @@ struct StackCheck : public Pass { module->addGlobal(stackLimit); PassRunner innerRunner(module); - EnforceStackLimit(stackPointer, stackLimit, builder, handler) + EnforceStackLimits(stackPointer, stackBase, stackLimit, builder, handler) .run(&innerRunner, module); - generateSetStackLimitFunction(*module); + generateSetStackLimitFunctions(*module); } }; |