diff options
-rw-r--r-- | src/passes/StackCheck.cpp | 13 | ||||
-rw-r--r-- | test/lld/basic_safe_stack.wat.out | 10 | ||||
-rw-r--r-- | test/lld/recursive_safe_stack.wat.out | 19 |
3 files changed, 30 insertions, 12 deletions
diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp index e7fc3c6b3..07cf46aee 100644 --- a/src/passes/StackCheck.cpp +++ b/src/passes/StackCheck.cpp @@ -35,11 +35,12 @@ namespace wasm { // Exported function to set the base and the limit. static Name SET_STACK_LIMITS("__set_stack_limits"); -static void importStackOverflowHandler(Module& module, Name name) { +static void +importStackOverflowHandler(Module& module, Name name, Signature sig) { ImportInfo info(module); if (!info.getImportedFunction(ENV, name)) { - auto import = Builder::makeFunction(name, Signature(), {}); + auto import = Builder::makeFunction(name, sig, {}); import->module = ENV; import->base = name; module.addFunction(std::move(import)); @@ -79,7 +80,10 @@ struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> { // Otherwise, just trap. Expression* handlerExpr; if (handler.is()) { - handlerExpr = builder.makeCall(handler, {}, Type::none); + handlerExpr = + builder.makeCall(handler, + {builder.makeLocalGet(newSP, stackPointer->type)}, + stackPointer->type); } else { handlerExpr = builder.makeUnreachable(); } @@ -133,7 +137,8 @@ struct StackCheck : public Pass { runner->options.getArgumentOrDefault("stack-check-handler", ""); if (handlerName != "") { handler = handlerName; - importStackOverflowHandler(*module, handler); + importStackOverflowHandler( + *module, handler, Signature({stackPointer->type}, Type::none)); } Builder builder(*module); diff --git a/test/lld/basic_safe_stack.wat.out b/test/lld/basic_safe_stack.wat.out index d6aca5760..6e1b83712 100644 --- a/test/lld/basic_safe_stack.wat.out +++ b/test/lld/basic_safe_stack.wat.out @@ -3,7 +3,7 @@ (type $i32_=>_none (func (param i32))) (type $i32_=>_i32 (func (param i32) (result i32))) (type $i32_i32_=>_none (func (param i32 i32))) - (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow)) + (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow (param i32))) (global $__stack_pointer (mut i32) (i32.const 66112)) (global $__stack_base (mut i32) (i32.const 0)) (global $__stack_limit (mut i32) (i32.const 0)) @@ -32,7 +32,9 @@ (global.get $__stack_limit) ) ) - (call $__handle_stack_overflow) + (call $__handle_stack_overflow + (local.get $1) + ) ) (global.set $__stack_pointer (local.get $1) @@ -64,7 +66,9 @@ (global.get $__stack_limit) ) ) - (call $__handle_stack_overflow) + (call $__handle_stack_overflow + (local.get $3) + ) ) (global.set $__stack_pointer (local.get $3) diff --git a/test/lld/recursive_safe_stack.wat.out b/test/lld/recursive_safe_stack.wat.out index 3272731b1..d6cd8f63a 100644 --- a/test/lld/recursive_safe_stack.wat.out +++ b/test/lld/recursive_safe_stack.wat.out @@ -2,9 +2,10 @@ (type $0 (func (param i32 i32) (result i32))) (type $1 (func)) (type $2 (func (result i32))) + (type $i32_=>_none (func (param i32))) (type $i32_i32_=>_none (func (param i32 i32))) (import "env" "printf" (func $printf (param i32 i32) (result i32))) - (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow)) + (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow (param i32))) (global $global$0 (mut i32) (i32.const 66128)) (global $global$1 i32 (i32.const 66128)) (global $global$2 i32 (i32.const 587)) @@ -45,7 +46,9 @@ (global.get $__stack_limit) ) ) - (call $__handle_stack_overflow) + (call $__handle_stack_overflow + (local.get $3) + ) ) (global.set $global$0 (local.get $3) @@ -82,7 +85,9 @@ (global.get $__stack_limit) ) ) - (call $__handle_stack_overflow) + (call $__handle_stack_overflow + (local.get $4) + ) ) (global.set $global$0 (local.get $4) @@ -116,7 +121,9 @@ (global.get $__stack_limit) ) ) - (call $__handle_stack_overflow) + (call $__handle_stack_overflow + (local.get $1) + ) ) (global.set $global$0 (local.get $1) @@ -152,7 +159,9 @@ (global.get $__stack_limit) ) ) - (call $__handle_stack_overflow) + (call $__handle_stack_overflow + (local.get $2) + ) ) (global.set $global$0 (local.get $2) |