summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuanzhong Chen <gzchen@google.com>2019-08-02 14:15:58 -0700
committerGitHub <noreply@github.com>2019-08-02 14:15:58 -0700
commit4f0d960ef686dff7d635cb6051d07111e6e27a27 (patch)
treefec00a82eaf658f01f13f7d831b47071bf08a9e6
parent8d4d43f0f239877b10789d3a85deb92f1927bc2e (diff)
downloadbinaryen-4f0d960ef686dff7d635cb6051d07111e6e27a27.tar.gz
binaryen-4f0d960ef686dff7d635cb6051d07111e6e27a27.tar.bz2
binaryen-4f0d960ef686dff7d635cb6051d07111e6e27a27.zip
Implement --check-stack-overflow flag for wasm-emscripten-finalize (#2278)
-rwxr-xr-xscripts/test/lld.py10
-rw-r--r--src/tools/wasm-emscripten-finalize.cpp12
-rw-r--r--src/wasm-emscripten.h6
-rw-r--r--src/wasm/wasm-emscripten.cpp121
-rw-r--r--test/lld/recursive_safe_stack.wast90
-rw-r--r--test/lld/recursive_safe_stack.wast.out244
6 files changed, 475 insertions, 8 deletions
diff --git a/scripts/test/lld.py b/scripts/test/lld.py
index fc68a4620..1a553506f 100755
--- a/scripts/test/lld.py
+++ b/scripts/test/lld.py
@@ -22,10 +22,12 @@ from .shared import (
def args_for_finalize(filename):
- if 'shared' in filename:
- return ['--side-module']
- else:
- return ['--global-base=568']
+ if 'safe_stack' in filename:
+ return ['--check-stack-overflow', '--global-base=568']
+ elif 'shared' in filename:
+ return ['--side-module']
+ else:
+ return ['--global-base=568']
def test_wasm_emscripten_finalize():
diff --git a/src/tools/wasm-emscripten-finalize.cpp b/src/tools/wasm-emscripten-finalize.cpp
index 56a6c3b39..f5afb3e24 100644
--- a/src/tools/wasm-emscripten-finalize.cpp
+++ b/src/tools/wasm-emscripten-finalize.cpp
@@ -48,6 +48,7 @@ int main(int argc, const char* argv[]) {
bool debugInfo = false;
bool isSideModule = false;
bool legalizeJavaScriptFFI = true;
+ bool checkStackOverflow = false;
uint64_t globalBase = INVALID_BASE;
ToolOptions options("wasm-emscripten-finalize",
"Performs Emscripten-specific transforms on .wasm files");
@@ -127,6 +128,13 @@ int main(int argc, const char* argv[]) {
[&dataSegmentFile](Options* o, const std::string& argument) {
dataSegmentFile = argument;
})
+ .add("--check-stack-overflow",
+ "",
+ "Check for stack overflows every time the stack is extended",
+ Options::Arguments::Zero,
+ [&checkStackOverflow](Options* o, const std::string&) {
+ checkStackOverflow = true;
+ })
.add_positional("INFILE",
Options::Arguments::One,
[&infile](Options* o, const std::string& argument) {
@@ -200,6 +208,10 @@ int main(int argc, const char* argv[]) {
}
wasm.updateMaps();
+ if (checkStackOverflow && !isSideModule) {
+ generator.enforceStackLimit();
+ }
+
if (isSideModule) {
generator.replaceStackPointerGlobal();
generator.generatePostInstantiateFunction();
diff --git a/src/wasm-emscripten.h b/src/wasm-emscripten.h
index 0841708f2..1758aa23a 100644
--- a/src/wasm-emscripten.h
+++ b/src/wasm-emscripten.h
@@ -54,6 +54,8 @@ public:
void fixInvokeFunctionNames();
+ void enforceStackLimit();
+
// Emits the data segments to a file. The file contains data from address base
// onwards (we must pass in base, as we can't tell it from the wasm - the
// first segment may start after a run of zeros, but we need those zeros in
@@ -71,11 +73,13 @@ private:
Global* getStackPointerGlobal();
Expression* generateLoadStackPointer();
- Expression* generateStoreStackPointer(Expression* value);
+ Expression* generateStoreStackPointer(Function* func, Expression* value);
void generateDynCallThunk(std::string sig);
void generateStackSaveFunction();
void generateStackAllocFunction();
void generateStackRestoreFunction();
+ void generateSetStackLimitFunction();
+ Name importStackOverflowHandler();
};
} // namespace wasm
diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp
index 3ee3e4424..c21016baa 100644
--- a/src/wasm/wasm-emscripten.cpp
+++ b/src/wasm/wasm-emscripten.cpp
@@ -37,8 +37,11 @@ static Name STACK_SAVE("stackSave");
static Name STACK_RESTORE("stackRestore");
static Name STACK_ALLOC("stackAlloc");
static Name STACK_INIT("stack$init");
+static Name STACK_LIMIT("__stack_limit");
+static Name SET_STACK_LIMIT("__set_stack_limit");
static Name POST_INSTANTIATE("__post_instantiate");
static Name ASSIGN_GOT_ENTIRES("__assign_got_enties");
+static Name STACK_OVERFLOW_IMPORT("__handle_stack_overflow");
void addExportedFunction(Module& wasm, Function* function) {
wasm.addFunction(function);
@@ -92,8 +95,32 @@ Expression* EmscriptenGlueGenerator::generateLoadStackPointer() {
return builder.makeGlobalGet(stackPointer->name, i32);
}
+inline Expression* stackBoundsCheck(Builder& builder,
+ Function* func,
+ Expression* value,
+ Global* stackPointer,
+ Global* stackLimit,
+ Name handler) {
+ // Add a local to store the value of the expression. We need the value twice:
+ // once to check if it has overflowed, and again to assign to store it.
+ auto newSP = Builder::addVar(func, stackPointer->type);
+ // (if (i32.lt_u (local.tee $newSP (...value...)) (global.get $__stack_limit))
+ // (call $handler))
+ auto check =
+ builder.makeIf(builder.makeBinary(
+ BinaryOp::LtUInt32,
+ builder.makeLocalTee(newSP, value),
+ builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
+ builder.makeCall(handler, {}, none));
+ // (global.set $__stack_pointer (local.get $newSP))
+ auto newSet = builder.makeGlobalSet(
+ stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
+ return builder.blockify(check, newSet);
+}
+
Expression*
-EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
+EmscriptenGlueGenerator::generateStoreStackPointer(Function* func,
+ Expression* value) {
if (!useStackPointerGlobal) {
return builder.makeStore(
/* bytes =*/4,
@@ -107,6 +134,14 @@ EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
if (!stackPointer) {
Fatal() << "stack pointer global not found";
}
+ if (auto* stackLimit = wasm.getGlobalOrNull(STACK_LIMIT)) {
+ return stackBoundsCheck(builder,
+ func,
+ value,
+ stackPointer,
+ stackLimit,
+ importStackOverflowHandler());
+ }
return builder.makeGlobalSet(stackPointer->name, value);
}
@@ -132,7 +167,7 @@ void EmscriptenGlueGenerator::generateStackAllocFunction() {
Const* subConst = builder.makeConst(Literal(~bitMask));
Binary* maskedSub = builder.makeBinary(AndInt32, sub, subConst);
LocalSet* teeStackLocal = builder.makeLocalTee(1, maskedSub);
- Expression* storeStack = generateStoreStackPointer(teeStackLocal);
+ Expression* storeStack = generateStoreStackPointer(function, teeStackLocal);
Block* block = builder.makeBlock();
block->list.push_back(storeStack);
@@ -149,7 +184,7 @@ void EmscriptenGlueGenerator::generateStackRestoreFunction() {
Function* function =
builder.makeFunction(STACK_RESTORE, std::move(params), none, {});
LocalGet* getArg = builder.makeLocalGet(0, i32);
- Expression* store = generateStoreStackPointer(getArg);
+ Expression* store = generateStoreStackPointer(function, getArg);
function->body = store;
@@ -444,6 +479,86 @@ void EmscriptenGlueGenerator::replaceStackPointerGlobal() {
wasm.removeGlobal(stackPointer->name);
}
+struct StackLimitEnforcer : public WalkerPass<PostWalker<StackLimitEnforcer>> {
+ StackLimitEnforcer(Global* stackPointer,
+ Global* stackLimit,
+ Builder& builder,
+ Name handler)
+ : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
+ handler(handler) {}
+
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override {
+ return new StackLimitEnforcer(stackPointer, stackLimit, builder, handler);
+ }
+
+ void visitGlobalSet(GlobalSet* curr) {
+ if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
+ replaceCurrent(stackBoundsCheck(builder,
+ getFunction(),
+ curr->value,
+ stackPointer,
+ stackLimit,
+ handler));
+ }
+ }
+
+private:
+ Global* stackPointer;
+ Global* stackLimit;
+ Builder& builder;
+ Name handler;
+};
+
+void EmscriptenGlueGenerator::enforceStackLimit() {
+ Global* stackPointer = getStackPointerGlobal();
+ if (!stackPointer) {
+ return;
+ }
+
+ auto* stackLimit = builder.makeGlobal(STACK_LIMIT,
+ stackPointer->type,
+ builder.makeConst(Literal(0)),
+ Builder::Mutable);
+ wasm.addGlobal(stackLimit);
+
+ auto handler = importStackOverflowHandler();
+
+ StackLimitEnforcer walker(stackPointer, stackLimit, builder, handler);
+ PassRunner runner(&wasm);
+ walker.run(&runner, &wasm);
+
+ generateSetStackLimitFunction();
+}
+
+void EmscriptenGlueGenerator::generateSetStackLimitFunction() {
+ Function* function =
+ builder.makeFunction(SET_STACK_LIMIT, std::vector<Type>({i32}), none, {});
+ LocalGet* getArg = builder.makeLocalGet(0, i32);
+ Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
+ function->body = store;
+ addExportedFunction(wasm, function);
+}
+
+Name EmscriptenGlueGenerator::importStackOverflowHandler() {
+ ImportInfo info(wasm);
+
+ if (auto* existing = info.getImportedFunction(ENV, STACK_OVERFLOW_IMPORT)) {
+ return existing->name;
+ } else {
+ auto* import = new Function;
+ import->name = STACK_OVERFLOW_IMPORT;
+ import->module = ENV;
+ import->base = STACK_OVERFLOW_IMPORT;
+ auto* functionType = ensureFunctionType("v", &wasm);
+ import->type = functionType->name;
+ FunctionTypeUtils::fillFunction(import, functionType);
+ wasm.addFunction(import);
+ return STACK_OVERFLOW_IMPORT;
+ }
+}
+
const Address UNKNOWN_OFFSET(uint32_t(-1));
std::vector<Address> getSegmentOffsets(Module& wasm) {
diff --git a/test/lld/recursive_safe_stack.wast b/test/lld/recursive_safe_stack.wast
new file mode 100644
index 000000000..67f7f3914
--- /dev/null
+++ b/test/lld/recursive_safe_stack.wast
@@ -0,0 +1,90 @@
+(module
+ (type $0 (func (param i32 i32) (result i32)))
+ (type $1 (func))
+ (type $2 (func (result i32)))
+ (import "env" "printf" (func $printf (param i32 i32) (result i32)))
+ (memory $0 2)
+ (data (i32.const 568) "%d:%d\n\00Result: %d\n\00")
+ (table $0 1 1 funcref)
+ (global $global$0 (mut i32) (i32.const 66128))
+ (global $global$1 i32 (i32.const 66128))
+ (global $global$2 i32 (i32.const 587))
+ (export "memory" (memory $0))
+ (export "__wasm_call_ctors" (func $__wasm_call_ctors))
+ (export "__heap_base" (global $global$1))
+ (export "__data_end" (global $global$2))
+ (export "main" (func $main))
+ (func $__wasm_call_ctors (; 1 ;) (type $1)
+ )
+ (func $foo (; 2 ;) (type $0) (param $0 i32) (param $1 i32) (result i32)
+ (local $2 i32)
+ (global.set $global$0
+ (local.tee $2
+ (i32.sub
+ (global.get $global$0)
+ (i32.const 16)
+ )
+ )
+ )
+ (i32.store offset=4
+ (local.get $2)
+ (local.get $1)
+ )
+ (i32.store
+ (local.get $2)
+ (local.get $0)
+ )
+ (drop
+ (call $printf
+ (i32.const 568)
+ (local.get $2)
+ )
+ )
+ (global.set $global$0
+ (i32.add
+ (local.get $2)
+ (i32.const 16)
+ )
+ )
+ (i32.add
+ (local.get $1)
+ (local.get $0)
+ )
+ )
+ (func $__original_main (; 3 ;) (type $2) (result i32)
+ (local $0 i32)
+ (global.set $global$0
+ (local.tee $0
+ (i32.sub
+ (global.get $global$0)
+ (i32.const 16)
+ )
+ )
+ )
+ (i32.store
+ (local.get $0)
+ (call $foo
+ (i32.const 1)
+ (i32.const 2)
+ )
+ )
+ (drop
+ (call $printf
+ (i32.const 575)
+ (local.get $0)
+ )
+ )
+ (global.set $global$0
+ (i32.add
+ (local.get $0)
+ (i32.const 16)
+ )
+ )
+ (i32.const 0)
+ )
+ (func $main (; 4 ;) (type $0) (param $0 i32) (param $1 i32) (result i32)
+ (call $__original_main)
+ )
+ ;; custom section "producers", size 111
+)
+
diff --git a/test/lld/recursive_safe_stack.wast.out b/test/lld/recursive_safe_stack.wast.out
new file mode 100644
index 000000000..7253a28fb
--- /dev/null
+++ b/test/lld/recursive_safe_stack.wast.out
@@ -0,0 +1,244 @@
+(module
+ (type $0 (func (param i32 i32) (result i32)))
+ (type $1 (func))
+ (type $2 (func (result i32)))
+ (type $FUNCSIG$v (func))
+ (import "env" "printf" (func $printf (param i32 i32) (result i32)))
+ (import "env" "__handle_stack_overflow" (func $__handle_stack_overflow))
+ (memory $0 2)
+ (data (i32.const 568) "%d:%d\n\00Result: %d\n\00")
+ (table $0 1 1 funcref)
+ (global $global$0 (mut i32) (i32.const 66128))
+ (global $global$1 i32 (i32.const 66128))
+ (global $global$2 i32 (i32.const 587))
+ (global $__stack_limit (mut i32) (i32.const 0))
+ (export "memory" (memory $0))
+ (export "__wasm_call_ctors" (func $__wasm_call_ctors))
+ (export "__heap_base" (global $global$1))
+ (export "__data_end" (global $global$2))
+ (export "main" (func $main))
+ (export "__set_stack_limit" (func $__set_stack_limit))
+ (export "stackSave" (func $stackSave))
+ (export "stackAlloc" (func $stackAlloc))
+ (export "stackRestore" (func $stackRestore))
+ (export "__growWasmMemory" (func $__growWasmMemory))
+ (func $__wasm_call_ctors (; 2 ;) (type $1)
+ (nop)
+ )
+ (func $foo (; 3 ;) (type $0) (param $0 i32) (param $1 i32) (result i32)
+ (local $2 i32)
+ (local $3 i32)
+ (local $4 i32)
+ (block
+ (if
+ (i32.lt_u
+ (local.tee $3
+ (local.tee $2
+ (i32.sub
+ (global.get $global$0)
+ (i32.const 16)
+ )
+ )
+ )
+ (global.get $__stack_limit)
+ )
+ (call $__handle_stack_overflow)
+ )
+ (global.set $global$0
+ (local.get $3)
+ )
+ )
+ (i32.store offset=4
+ (local.get $2)
+ (local.get $1)
+ )
+ (i32.store
+ (local.get $2)
+ (local.get $0)
+ )
+ (drop
+ (call $printf
+ (i32.const 568)
+ (local.get $2)
+ )
+ )
+ (block
+ (if
+ (i32.lt_u
+ (local.tee $4
+ (i32.add
+ (local.get $2)
+ (i32.const 16)
+ )
+ )
+ (global.get $__stack_limit)
+ )
+ (call $__handle_stack_overflow)
+ )
+ (global.set $global$0
+ (local.get $4)
+ )
+ )
+ (i32.add
+ (local.get $1)
+ (local.get $0)
+ )
+ )
+ (func $__original_main (; 4 ;) (type $2) (result i32)
+ (local $0 i32)
+ (local $1 i32)
+ (local $2 i32)
+ (block
+ (if
+ (i32.lt_u
+ (local.tee $1
+ (local.tee $0
+ (i32.sub
+ (global.get $global$0)
+ (i32.const 16)
+ )
+ )
+ )
+ (global.get $__stack_limit)
+ )
+ (call $__handle_stack_overflow)
+ )
+ (global.set $global$0
+ (local.get $1)
+ )
+ )
+ (i32.store
+ (local.get $0)
+ (call $foo
+ (i32.const 1)
+ (i32.const 2)
+ )
+ )
+ (drop
+ (call $printf
+ (i32.const 575)
+ (local.get $0)
+ )
+ )
+ (block
+ (if
+ (i32.lt_u
+ (local.tee $2
+ (i32.add
+ (local.get $0)
+ (i32.const 16)
+ )
+ )
+ (global.get $__stack_limit)
+ )
+ (call $__handle_stack_overflow)
+ )
+ (global.set $global$0
+ (local.get $2)
+ )
+ )
+ (i32.const 0)
+ )
+ (func $main (; 5 ;) (type $0) (param $0 i32) (param $1 i32) (result i32)
+ (call $__original_main)
+ )
+ (func $__set_stack_limit (; 6 ;) (param $0 i32)
+ (global.set $__stack_limit
+ (local.get $0)
+ )
+ )
+ (func $stackSave (; 7 ;) (result i32)
+ (global.get $global$0)
+ )
+ (func $stackAlloc (; 8 ;) (param $0 i32) (result i32)
+ (local $1 i32)
+ (local $2 i32)
+ (block
+ (if
+ (i32.lt_u
+ (local.tee $2
+ (local.tee $1
+ (i32.and
+ (i32.sub
+ (global.get $global$0)
+ (local.get $0)
+ )
+ (i32.const -16)
+ )
+ )
+ )
+ (global.get $__stack_limit)
+ )
+ (call $__handle_stack_overflow)
+ )
+ (global.set $global$0
+ (local.get $2)
+ )
+ )
+ (local.get $1)
+ )
+ (func $stackRestore (; 9 ;) (param $0 i32)
+ (local $1 i32)
+ (if
+ (i32.lt_u
+ (local.tee $1
+ (local.get $0)
+ )
+ (global.get $__stack_limit)
+ )
+ (call $__handle_stack_overflow)
+ )
+ (global.set $global$0
+ (local.get $1)
+ )
+ )
+ (func $__growWasmMemory (; 10 ;) (param $newSize i32) (result i32)
+ (memory.grow
+ (local.get $newSize)
+ )
+ )
+)
+(;
+--BEGIN METADATA --
+{
+ "staticBump": 19,
+ "tableSize": 1,
+ "initializers": [
+ "__wasm_call_ctors"
+ ],
+ "declares": [
+ "printf",
+ "__handle_stack_overflow"
+ ],
+ "externs": [
+ ],
+ "implementedFunctions": [
+ "___wasm_call_ctors",
+ "_main",
+ "___set_stack_limit",
+ "_stackSave",
+ "_stackAlloc",
+ "_stackRestore",
+ "___growWasmMemory"
+ ],
+ "exports": [
+ "__wasm_call_ctors",
+ "main",
+ "__set_stack_limit",
+ "stackSave",
+ "stackAlloc",
+ "stackRestore",
+ "__growWasmMemory"
+ ],
+ "namedGlobals": {
+ "__heap_base" : "66128",
+ "__data_end" : "587"
+ },
+ "invokeFuncs": [
+ ],
+ "features": [
+ ],
+ "mainReadsParams": 0
+}
+-- END METADATA --
+;)