summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/passes/StackCheck.cpp13
-rw-r--r--test/lld/basic_safe_stack.wat.out10
-rw-r--r--test/lld/recursive_safe_stack.wat.out19
3 files changed, 30 insertions, 12 deletions
diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp
index e7fc3c6b3..07cf46aee 100644
--- a/src/passes/StackCheck.cpp
+++ b/src/passes/StackCheck.cpp
@@ -35,11 +35,12 @@ namespace wasm {
// Exported function to set the base and the limit.
static Name SET_STACK_LIMITS("__set_stack_limits");
-static void importStackOverflowHandler(Module& module, Name name) {
+static void
+importStackOverflowHandler(Module& module, Name name, Signature sig) {
ImportInfo info(module);
if (!info.getImportedFunction(ENV, name)) {
- auto import = Builder::makeFunction(name, Signature(), {});
+ auto import = Builder::makeFunction(name, sig, {});
import->module = ENV;
import->base = name;
module.addFunction(std::move(import));
@@ -79,7 +80,10 @@ struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> {
// Otherwise, just trap.
Expression* handlerExpr;
if (handler.is()) {
- handlerExpr = builder.makeCall(handler, {}, Type::none);
+ handlerExpr =
+ builder.makeCall(handler,
+ {builder.makeLocalGet(newSP, stackPointer->type)},
+ stackPointer->type);
} else {
handlerExpr = builder.makeUnreachable();
}
@@ -133,7 +137,8 @@ struct StackCheck : public Pass {
runner->options.getArgumentOrDefault("stack-check-handler", "");
if (handlerName != "") {
handler = handlerName;
- importStackOverflowHandler(*module, handler);
+ importStackOverflowHandler(
+ *module, handler, Signature({stackPointer->type}, Type::none));
}
Builder builder(*module);
diff --git a/test/lld/basic_safe_stack.wat.out b/test/lld/basic_safe_stack.wat.out
index d6aca5760..6e1b83712 100644
--- a/test/lld/basic_safe_stack.wat.out
+++ b/test/lld/basic_safe_stack.wat.out
@@ -3,7 +3,7 @@
(type $i32_=>_none (func (param i32)))
(type $i32_=>_i32 (func (param i32) (result i32)))
(type $i32_i32_=>_none (func (param i32 i32)))
- (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow))
+ (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow (param i32)))
(global $__stack_pointer (mut i32) (i32.const 66112))
(global $__stack_base (mut i32) (i32.const 0))
(global $__stack_limit (mut i32) (i32.const 0))
@@ -32,7 +32,9 @@
(global.get $__stack_limit)
)
)
- (call $__handle_stack_overflow)
+ (call $__handle_stack_overflow
+ (local.get $1)
+ )
)
(global.set $__stack_pointer
(local.get $1)
@@ -64,7 +66,9 @@
(global.get $__stack_limit)
)
)
- (call $__handle_stack_overflow)
+ (call $__handle_stack_overflow
+ (local.get $3)
+ )
)
(global.set $__stack_pointer
(local.get $3)
diff --git a/test/lld/recursive_safe_stack.wat.out b/test/lld/recursive_safe_stack.wat.out
index 3272731b1..d6cd8f63a 100644
--- a/test/lld/recursive_safe_stack.wat.out
+++ b/test/lld/recursive_safe_stack.wat.out
@@ -2,9 +2,10 @@
(type $0 (func (param i32 i32) (result i32)))
(type $1 (func))
(type $2 (func (result i32)))
+ (type $i32_=>_none (func (param i32)))
(type $i32_i32_=>_none (func (param i32 i32)))
(import "env" "printf" (func $printf (param i32 i32) (result i32)))
- (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow))
+ (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow (param i32)))
(global $global$0 (mut i32) (i32.const 66128))
(global $global$1 i32 (i32.const 66128))
(global $global$2 i32 (i32.const 587))
@@ -45,7 +46,9 @@
(global.get $__stack_limit)
)
)
- (call $__handle_stack_overflow)
+ (call $__handle_stack_overflow
+ (local.get $3)
+ )
)
(global.set $global$0
(local.get $3)
@@ -82,7 +85,9 @@
(global.get $__stack_limit)
)
)
- (call $__handle_stack_overflow)
+ (call $__handle_stack_overflow
+ (local.get $4)
+ )
)
(global.set $global$0
(local.get $4)
@@ -116,7 +121,9 @@
(global.get $__stack_limit)
)
)
- (call $__handle_stack_overflow)
+ (call $__handle_stack_overflow
+ (local.get $1)
+ )
)
(global.set $global$0
(local.get $1)
@@ -152,7 +159,9 @@
(global.get $__stack_limit)
)
)
- (call $__handle_stack_overflow)
+ (call $__handle_stack_overflow
+ (local.get $2)
+ )
)
(global.set $global$0
(local.get $2)