diff options
-rw-r--r-- | src/passes/StackCheck.cpp | 77 | ||||
-rw-r--r-- | test/lld/basic_safe_stack.wat.out | 50 | ||||
-rw-r--r-- | test/lld/recursive_safe_stack.wat.out | 88 | ||||
-rw-r--r-- | test/lld/safe_stack_standalone-wasm.wat.out | 88 | ||||
-rw-r--r-- | test/passes/stack-check_enable-mutable-globals.txt | 48 | ||||
-rw-r--r-- | test/passes/stack-check_enable-mutable-globals.wast | 7 |
6 files changed, 269 insertions, 89 deletions
diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp index 4c797351b..183420fba 100644 --- a/src/passes/StackCheck.cpp +++ b/src/passes/StackCheck.cpp @@ -31,8 +31,16 @@ namespace wasm { +// The base is where the stack begins. As it goes down, that is the highest +// valid address. +static Name STACK_BASE("__stack_base"); +// The limit is the farthest it can grow to, which is the lowest valid address. static Name STACK_LIMIT("__stack_limit"); +// Old version, which sets the limit. +// TODO: remove this static Name SET_STACK_LIMIT("__set_stack_limit"); +// New version, which sets the base and the limit. +static Name SET_STACK_LIMITS("__set_stack_limits"); static void importStackOverflowHandler(Module& module, Name name) { ImportInfo info(module); @@ -55,34 +63,43 @@ static void addExportedFunction(Module& module, Function* function) { module.addExport(export_); } -static void generateSetStackLimitFunction(Module& module) { +static void generateSetStackLimitFunctions(Module& module) { Builder builder(module); - Function* function = + // One-parameter version + Function* limitFunc = builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {}); LocalGet* getArg = builder.makeLocalGet(0, Type::i32); Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg); - function->body = store; - addExportedFunction(module, function); + limitFunc->body = store; + addExportedFunction(module, limitFunc); + // Two-parameter version + Function* limitsFunc = builder.makeFunction( + SET_STACK_LIMITS, Signature({Type::i32, Type::i32}, Type::none), {}); + LocalGet* getBase = builder.makeLocalGet(0, Type::i32); + Expression* storeBase = builder.makeGlobalSet(STACK_BASE, getBase); + LocalGet* getLimit = builder.makeLocalGet(1, Type::i32); + Expression* storeLimit = builder.makeGlobalSet(STACK_LIMIT, getLimit); + limitsFunc->body = builder.makeBlock({storeBase, storeLimit}); + addExportedFunction(module, limitsFunc); } -struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> { - EnforceStackLimit(Global* stackPointer, - Global* stackLimit, - Builder& builder, - Name handler) - : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder), - handler(handler) {} +struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> { + EnforceStackLimits(Global* stackPointer, + Global* stackBase, + Global* stackLimit, + Builder& builder, + Name handler) + : stackPointer(stackPointer), stackBase(stackBase), stackLimit(stackLimit), + builder(builder), handler(handler) {} bool isFunctionParallel() override { return true; } Pass* create() override { - return new EnforceStackLimit(stackPointer, stackLimit, builder, handler); + return new EnforceStackLimits( + stackPointer, stackBase, stackLimit, builder, handler); } - Expression* stackBoundsCheck(Function* func, - Expression* value, - Global* stackPointer, - Global* stackLimit) { + Expression* stackBoundsCheck(Function* func, Expression* value) { // Add a local to store the value of the expression. We need the value // twice: once to check if it has overflowed, and again to assign to store // it. @@ -95,12 +112,18 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> { } else { handlerExpr = builder.makeUnreachable(); } - // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit)) + // If it is >= the base or <= the limit, then error. auto check = builder.makeIf( builder.makeBinary( - BinaryOp::LtUInt32, - builder.makeLocalTee(newSP, value, stackPointer->type), - builder.makeGlobalGet(stackLimit->name, stackLimit->type)), + BinaryOp::OrInt32, + builder.makeBinary( + BinaryOp::GtUInt32, + builder.makeLocalTee(newSP, value, stackPointer->type), + builder.makeGlobalGet(stackBase->name, stackBase->type)), + builder.makeBinary( + BinaryOp::LtUInt32, + builder.makeLocalGet(newSP, stackPointer->type), + builder.makeGlobalGet(stackLimit->name, stackLimit->type))), handlerExpr); // (global.set $__stack_pointer (local.get $newSP)) auto newSet = builder.makeGlobalSet( @@ -110,13 +133,13 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> { void visitGlobalSet(GlobalSet* curr) { if (getModule()->getGlobalOrNull(curr->name) == stackPointer) { - replaceCurrent( - stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit)); + replaceCurrent(stackBoundsCheck(getFunction(), curr->value)); } } private: Global* stackPointer; + Global* stackBase; Global* stackLimit; Builder& builder; Name handler; @@ -139,6 +162,12 @@ struct StackCheck : public Pass { } Builder builder(*module); + Global* stackBase = builder.makeGlobal(STACK_BASE, + stackPointer->type, + builder.makeConst(int32_t(0)), + Builder::Mutable); + module->addGlobal(stackBase); + Global* stackLimit = builder.makeGlobal(STACK_LIMIT, stackPointer->type, builder.makeConst(int32_t(0)), @@ -146,9 +175,9 @@ struct StackCheck : public Pass { module->addGlobal(stackLimit); PassRunner innerRunner(module); - EnforceStackLimit(stackPointer, stackLimit, builder, handler) + EnforceStackLimits(stackPointer, stackBase, stackLimit, builder, handler) .run(&innerRunner, module); - generateSetStackLimitFunction(*module); + generateSetStackLimitFunctions(*module); } }; diff --git a/test/lld/basic_safe_stack.wat.out b/test/lld/basic_safe_stack.wat.out index a72f3d215..9e6a916da 100644 --- a/test/lld/basic_safe_stack.wat.out +++ b/test/lld/basic_safe_stack.wat.out @@ -2,11 +2,13 @@ (type $none_=>_none (func)) (type $i32_=>_none (func (param i32))) (type $i32_=>_i32 (func (param i32) (result i32))) + (type $i32_i32_=>_none (func (param i32 i32))) (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow)) (memory $0 2) (table $0 1 1 funcref) (global $global$0 (mut i32) (i32.const 66112)) (global $global$1 i32 (i32.const 568)) + (global $__stack_base (mut i32) (i32.const 0)) (global $__stack_limit (mut i32) (i32.const 0)) (export "memory" (memory $0)) (export "__wasm_call_ctors" (func $__wasm_call_ctors)) @@ -15,6 +17,7 @@ (export "main" (func $main)) (export "__data_end" (global $global$1)) (export "__set_stack_limit" (func $__set_stack_limit)) + (export "__set_stack_limits" (func $__set_stack_limits)) (export "__growWasmMemory" (func $__growWasmMemory)) (func $__wasm_call_ctors (nop) @@ -22,11 +25,17 @@ (func $stackRestore (param $0 i32) (local $1 i32) (if - (i32.lt_u - (local.tee $1 - (local.get $0) + (i32.or + (i32.gt_u + (local.tee $1 + (local.get $0) + ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $1) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (call $__handle_stack_overflow) ) @@ -40,19 +49,25 @@ (local $3 i32) (block (if - (i32.lt_u - (local.tee $3 - (local.tee $1 - (i32.and - (i32.sub - (global.get $global$0) - (local.get $0) + (i32.or + (i32.gt_u + (local.tee $3 + (local.tee $1 + (i32.and + (i32.sub + (global.get $global$0) + (local.get $0) + ) + (i32.const -16) ) - (i32.const -16) ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $3) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (call $__handle_stack_overflow) ) @@ -70,6 +85,14 @@ (local.get $0) ) ) + (func $__set_stack_limits (param $0 i32) (param $1 i32) + (global.set $__stack_base + (local.get $0) + ) + (global.set $__stack_limit + (local.get $1) + ) + ) (func $__growWasmMemory (param $newSize i32) (result i32) (memory.grow (local.get $newSize) @@ -95,6 +118,7 @@ "stackAlloc", "main", "__set_stack_limit", + "__set_stack_limits", "__growWasmMemory" ], "namedGlobals": { diff --git a/test/lld/recursive_safe_stack.wat.out b/test/lld/recursive_safe_stack.wat.out index b84820e09..cceca00ab 100644 --- a/test/lld/recursive_safe_stack.wat.out +++ b/test/lld/recursive_safe_stack.wat.out @@ -2,6 +2,7 @@ (type $i32_i32_=>_i32 (func (param i32 i32) (result i32))) (type $none_=>_none (func)) (type $i32_=>_none (func (param i32))) + (type $i32_i32_=>_none (func (param i32 i32))) (type $none_=>_i32 (func (result i32))) (type $i32_=>_i32 (func (param i32) (result i32))) (import "env" "printf" (func $printf (param i32 i32) (result i32))) @@ -12,6 +13,7 @@ (global $global$0 (mut i32) (i32.const 66128)) (global $global$1 i32 (i32.const 66128)) (global $global$2 i32 (i32.const 587)) + (global $__stack_base (mut i32) (i32.const 0)) (global $__stack_limit (mut i32) (i32.const 0)) (export "memory" (memory $0)) (export "__wasm_call_ctors" (func $__wasm_call_ctors)) @@ -19,6 +21,7 @@ (export "__data_end" (global $global$2)) (export "main" (func $main)) (export "__set_stack_limit" (func $__set_stack_limit)) + (export "__set_stack_limits" (func $__set_stack_limits)) (export "__growWasmMemory" (func $__growWasmMemory)) (func $__wasm_call_ctors (nop) @@ -29,16 +32,22 @@ (local $4 i32) (block (if - (i32.lt_u - (local.tee $3 - (local.tee $2 - (i32.sub - (global.get $global$0) - (i32.const 16) + (i32.or + (i32.gt_u + (local.tee $3 + (local.tee $2 + (i32.sub + (global.get $global$0) + (i32.const 16) + ) ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $3) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (call $__handle_stack_overflow) ) @@ -62,14 +71,20 @@ ) (block (if - (i32.lt_u - (local.tee $4 - (i32.add - (local.get $2) - (i32.const 16) + (i32.or + (i32.gt_u + (local.tee $4 + (i32.add + (local.get $2) + (i32.const 16) + ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $4) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (call $__handle_stack_overflow) ) @@ -88,16 +103,22 @@ (local $2 i32) (block (if - (i32.lt_u - (local.tee $1 - (local.tee $0 - (i32.sub - (global.get $global$0) - (i32.const 16) + (i32.or + (i32.gt_u + (local.tee $1 + (local.tee $0 + (i32.sub + (global.get $global$0) + (i32.const 16) + ) ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $1) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (call $__handle_stack_overflow) ) @@ -120,14 +141,20 @@ ) (block (if - (i32.lt_u - (local.tee $2 - (i32.add - (local.get $0) - (i32.const 16) + (i32.or + (i32.gt_u + (local.tee $2 + (i32.add + (local.get $0) + (i32.const 16) + ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $2) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (call $__handle_stack_overflow) ) @@ -145,6 +172,14 @@ (local.get $0) ) ) + (func $__set_stack_limits (param $0 i32) (param $1 i32) + (global.set $__stack_base + (local.get $0) + ) + (global.set $__stack_limit + (local.get $1) + ) + ) (func $__growWasmMemory (param $newSize i32) (result i32) (memory.grow (local.get $newSize) @@ -169,6 +204,7 @@ "__wasm_call_ctors", "main", "__set_stack_limit", + "__set_stack_limits", "__growWasmMemory" ], "namedGlobals": { diff --git a/test/lld/safe_stack_standalone-wasm.wat.out b/test/lld/safe_stack_standalone-wasm.wat.out index 48ea2be68..51b88661b 100644 --- a/test/lld/safe_stack_standalone-wasm.wat.out +++ b/test/lld/safe_stack_standalone-wasm.wat.out @@ -2,6 +2,7 @@ (type $i32_i32_=>_i32 (func (param i32 i32) (result i32))) (type $none_=>_none (func)) (type $i32_=>_none (func (param i32))) + (type $i32_i32_=>_none (func (param i32 i32))) (type $none_=>_i32 (func (result i32))) (type $i32_=>_i32 (func (param i32) (result i32))) (import "env" "printf" (func $printf (param i32 i32) (result i32))) @@ -11,6 +12,7 @@ (global $global$0 (mut i32) (i32.const 66128)) (global $global$1 i32 (i32.const 66128)) (global $global$2 i32 (i32.const 587)) + (global $__stack_base (mut i32) (i32.const 0)) (global $__stack_limit (mut i32) (i32.const 0)) (export "memory" (memory $0)) (export "__wasm_call_ctors" (func $__wasm_call_ctors)) @@ -18,6 +20,7 @@ (export "__data_end" (global $global$2)) (export "main" (func $main)) (export "__set_stack_limit" (func $__set_stack_limit)) + (export "__set_stack_limits" (func $__set_stack_limits)) (export "__growWasmMemory" (func $__growWasmMemory)) (func $__wasm_call_ctors (nop) @@ -28,16 +31,22 @@ (local $4 i32) (block (if - (i32.lt_u - (local.tee $3 - (local.tee $2 - (i32.sub - (global.get $global$0) - (i32.const 16) + (i32.or + (i32.gt_u + (local.tee $3 + (local.tee $2 + (i32.sub + (global.get $global$0) + (i32.const 16) + ) ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $3) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (unreachable) ) @@ -61,14 +70,20 @@ ) (block (if - (i32.lt_u - (local.tee $4 - (i32.add - (local.get $2) - (i32.const 16) + (i32.or + (i32.gt_u + (local.tee $4 + (i32.add + (local.get $2) + (i32.const 16) + ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $4) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (unreachable) ) @@ -87,16 +102,22 @@ (local $2 i32) (block (if - (i32.lt_u - (local.tee $1 - (local.tee $0 - (i32.sub - (global.get $global$0) - (i32.const 16) + (i32.or + (i32.gt_u + (local.tee $1 + (local.tee $0 + (i32.sub + (global.get $global$0) + (i32.const 16) + ) ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $1) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (unreachable) ) @@ -119,14 +140,20 @@ ) (block (if - (i32.lt_u - (local.tee $2 - (i32.add - (local.get $0) - (i32.const 16) + (i32.or + (i32.gt_u + (local.tee $2 + (i32.add + (local.get $0) + (i32.const 16) + ) ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $2) + (global.get $__stack_limit) ) - (global.get $__stack_limit) ) (unreachable) ) @@ -144,6 +171,14 @@ (local.get $0) ) ) + (func $__set_stack_limits (param $0 i32) (param $1 i32) + (global.set $__stack_base + (local.get $0) + ) + (global.set $__stack_limit + (local.get $1) + ) + ) (func $__growWasmMemory (param $newSize i32) (result i32) (memory.grow (local.get $newSize) @@ -167,6 +202,7 @@ "__wasm_call_ctors", "main", "__set_stack_limit", + "__set_stack_limits", "__growWasmMemory" ], "namedGlobals": { diff --git a/test/passes/stack-check_enable-mutable-globals.txt b/test/passes/stack-check_enable-mutable-globals.txt new file mode 100644 index 000000000..52ee091ef --- /dev/null +++ b/test/passes/stack-check_enable-mutable-globals.txt @@ -0,0 +1,48 @@ +(module + (type $i32_=>_none (func (param i32))) + (type $i32_i32_=>_none (func (param i32 i32))) + (type $none_=>_i32 (func (result i32))) + (import "env" "__stack_pointer" (global $sp (mut i32))) + (global $__stack_base (mut i32) (i32.const 0)) + (global $__stack_limit (mut i32) (i32.const 0)) + (export "use_stack" (func $0)) + (export "__set_stack_limit" (func $__set_stack_limit)) + (export "__set_stack_limits" (func $__set_stack_limits)) + (func $0 (result i32) + (local $0 i32) + (block + (if + (i32.or + (i32.gt_u + (local.tee $0 + (i32.const 42) + ) + (global.get $__stack_base) + ) + (i32.lt_u + (local.get $0) + (global.get $__stack_limit) + ) + ) + (unreachable) + ) + (global.set $sp + (local.get $0) + ) + ) + (global.get $sp) + ) + (func $__set_stack_limit (param $0 i32) + (global.set $__stack_limit + (local.get $0) + ) + ) + (func $__set_stack_limits (param $0 i32) (param $1 i32) + (global.set $__stack_base + (local.get $0) + ) + (global.set $__stack_limit + (local.get $1) + ) + ) +) diff --git a/test/passes/stack-check_enable-mutable-globals.wast b/test/passes/stack-check_enable-mutable-globals.wast new file mode 100644 index 000000000..3028039fd --- /dev/null +++ b/test/passes/stack-check_enable-mutable-globals.wast @@ -0,0 +1,7 @@ +(module + (import "env" "__stack_pointer" (global $sp (mut i32))) + (func "use_stack" (result i32) + (global.set $sp (i32.const 42)) + (global.get $sp) + ) +) |