summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGuanzhong Chen <gzchen@google.com>2019-08-02 14:15:58 -0700
committerGitHub <noreply@github.com>2019-08-02 14:15:58 -0700
commit4f0d960ef686dff7d635cb6051d07111e6e27a27 (patch)
treefec00a82eaf658f01f13f7d831b47071bf08a9e6 /src
parent8d4d43f0f239877b10789d3a85deb92f1927bc2e (diff)
downloadbinaryen-4f0d960ef686dff7d635cb6051d07111e6e27a27.tar.gz
binaryen-4f0d960ef686dff7d635cb6051d07111e6e27a27.tar.bz2
binaryen-4f0d960ef686dff7d635cb6051d07111e6e27a27.zip
Implement --check-stack-overflow flag for wasm-emscripten-finalize (#2278)
Diffstat (limited to 'src')
-rw-r--r--src/tools/wasm-emscripten-finalize.cpp12
-rw-r--r--src/wasm-emscripten.h6
-rw-r--r--src/wasm/wasm-emscripten.cpp121
3 files changed, 135 insertions, 4 deletions
diff --git a/src/tools/wasm-emscripten-finalize.cpp b/src/tools/wasm-emscripten-finalize.cpp
index 56a6c3b39..f5afb3e24 100644
--- a/src/tools/wasm-emscripten-finalize.cpp
+++ b/src/tools/wasm-emscripten-finalize.cpp
@@ -48,6 +48,7 @@ int main(int argc, const char* argv[]) {
bool debugInfo = false;
bool isSideModule = false;
bool legalizeJavaScriptFFI = true;
+ bool checkStackOverflow = false;
uint64_t globalBase = INVALID_BASE;
ToolOptions options("wasm-emscripten-finalize",
"Performs Emscripten-specific transforms on .wasm files");
@@ -127,6 +128,13 @@ int main(int argc, const char* argv[]) {
[&dataSegmentFile](Options* o, const std::string& argument) {
dataSegmentFile = argument;
})
+ .add("--check-stack-overflow",
+ "",
+ "Check for stack overflows every time the stack is extended",
+ Options::Arguments::Zero,
+ [&checkStackOverflow](Options* o, const std::string&) {
+ checkStackOverflow = true;
+ })
.add_positional("INFILE",
Options::Arguments::One,
[&infile](Options* o, const std::string& argument) {
@@ -200,6 +208,10 @@ int main(int argc, const char* argv[]) {
}
wasm.updateMaps();
+ if (checkStackOverflow && !isSideModule) {
+ generator.enforceStackLimit();
+ }
+
if (isSideModule) {
generator.replaceStackPointerGlobal();
generator.generatePostInstantiateFunction();
diff --git a/src/wasm-emscripten.h b/src/wasm-emscripten.h
index 0841708f2..1758aa23a 100644
--- a/src/wasm-emscripten.h
+++ b/src/wasm-emscripten.h
@@ -54,6 +54,8 @@ public:
void fixInvokeFunctionNames();
+ void enforceStackLimit();
+
// Emits the data segments to a file. The file contains data from address base
// onwards (we must pass in base, as we can't tell it from the wasm - the
// first segment may start after a run of zeros, but we need those zeros in
@@ -71,11 +73,13 @@ private:
Global* getStackPointerGlobal();
Expression* generateLoadStackPointer();
- Expression* generateStoreStackPointer(Expression* value);
+ Expression* generateStoreStackPointer(Function* func, Expression* value);
void generateDynCallThunk(std::string sig);
void generateStackSaveFunction();
void generateStackAllocFunction();
void generateStackRestoreFunction();
+ void generateSetStackLimitFunction();
+ Name importStackOverflowHandler();
};
} // namespace wasm
diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp
index 3ee3e4424..c21016baa 100644
--- a/src/wasm/wasm-emscripten.cpp
+++ b/src/wasm/wasm-emscripten.cpp
@@ -37,8 +37,11 @@ static Name STACK_SAVE("stackSave");
static Name STACK_RESTORE("stackRestore");
static Name STACK_ALLOC("stackAlloc");
static Name STACK_INIT("stack$init");
+static Name STACK_LIMIT("__stack_limit");
+static Name SET_STACK_LIMIT("__set_stack_limit");
static Name POST_INSTANTIATE("__post_instantiate");
static Name ASSIGN_GOT_ENTIRES("__assign_got_enties");
+static Name STACK_OVERFLOW_IMPORT("__handle_stack_overflow");
void addExportedFunction(Module& wasm, Function* function) {
wasm.addFunction(function);
@@ -92,8 +95,32 @@ Expression* EmscriptenGlueGenerator::generateLoadStackPointer() {
return builder.makeGlobalGet(stackPointer->name, i32);
}
+inline Expression* stackBoundsCheck(Builder& builder,
+ Function* func,
+ Expression* value,
+ Global* stackPointer,
+ Global* stackLimit,
+ Name handler) {
+ // Add a local to store the value of the expression. We need the value twice:
+ // once to check if it has overflowed, and again to assign to store it.
+ auto newSP = Builder::addVar(func, stackPointer->type);
+ // (if (i32.lt_u (local.tee $newSP (...value...)) (global.get $__stack_limit))
+ // (call $handler))
+ auto check =
+ builder.makeIf(builder.makeBinary(
+ BinaryOp::LtUInt32,
+ builder.makeLocalTee(newSP, value),
+ builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
+ builder.makeCall(handler, {}, none));
+ // (global.set $__stack_pointer (local.get $newSP))
+ auto newSet = builder.makeGlobalSet(
+ stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
+ return builder.blockify(check, newSet);
+}
+
Expression*
-EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
+EmscriptenGlueGenerator::generateStoreStackPointer(Function* func,
+ Expression* value) {
if (!useStackPointerGlobal) {
return builder.makeStore(
/* bytes =*/4,
@@ -107,6 +134,14 @@ EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
if (!stackPointer) {
Fatal() << "stack pointer global not found";
}
+ if (auto* stackLimit = wasm.getGlobalOrNull(STACK_LIMIT)) {
+ return stackBoundsCheck(builder,
+ func,
+ value,
+ stackPointer,
+ stackLimit,
+ importStackOverflowHandler());
+ }
return builder.makeGlobalSet(stackPointer->name, value);
}
@@ -132,7 +167,7 @@ void EmscriptenGlueGenerator::generateStackAllocFunction() {
Const* subConst = builder.makeConst(Literal(~bitMask));
Binary* maskedSub = builder.makeBinary(AndInt32, sub, subConst);
LocalSet* teeStackLocal = builder.makeLocalTee(1, maskedSub);
- Expression* storeStack = generateStoreStackPointer(teeStackLocal);
+ Expression* storeStack = generateStoreStackPointer(function, teeStackLocal);
Block* block = builder.makeBlock();
block->list.push_back(storeStack);
@@ -149,7 +184,7 @@ void EmscriptenGlueGenerator::generateStackRestoreFunction() {
Function* function =
builder.makeFunction(STACK_RESTORE, std::move(params), none, {});
LocalGet* getArg = builder.makeLocalGet(0, i32);
- Expression* store = generateStoreStackPointer(getArg);
+ Expression* store = generateStoreStackPointer(function, getArg);
function->body = store;
@@ -444,6 +479,86 @@ void EmscriptenGlueGenerator::replaceStackPointerGlobal() {
wasm.removeGlobal(stackPointer->name);
}
+struct StackLimitEnforcer : public WalkerPass<PostWalker<StackLimitEnforcer>> {
+ StackLimitEnforcer(Global* stackPointer,
+ Global* stackLimit,
+ Builder& builder,
+ Name handler)
+ : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
+ handler(handler) {}
+
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override {
+ return new StackLimitEnforcer(stackPointer, stackLimit, builder, handler);
+ }
+
+ void visitGlobalSet(GlobalSet* curr) {
+ if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
+ replaceCurrent(stackBoundsCheck(builder,
+ getFunction(),
+ curr->value,
+ stackPointer,
+ stackLimit,
+ handler));
+ }
+ }
+
+private:
+ Global* stackPointer;
+ Global* stackLimit;
+ Builder& builder;
+ Name handler;
+};
+
+void EmscriptenGlueGenerator::enforceStackLimit() {
+ Global* stackPointer = getStackPointerGlobal();
+ if (!stackPointer) {
+ return;
+ }
+
+ auto* stackLimit = builder.makeGlobal(STACK_LIMIT,
+ stackPointer->type,
+ builder.makeConst(Literal(0)),
+ Builder::Mutable);
+ wasm.addGlobal(stackLimit);
+
+ auto handler = importStackOverflowHandler();
+
+ StackLimitEnforcer walker(stackPointer, stackLimit, builder, handler);
+ PassRunner runner(&wasm);
+ walker.run(&runner, &wasm);
+
+ generateSetStackLimitFunction();
+}
+
+void EmscriptenGlueGenerator::generateSetStackLimitFunction() {
+ Function* function =
+ builder.makeFunction(SET_STACK_LIMIT, std::vector<Type>({i32}), none, {});
+ LocalGet* getArg = builder.makeLocalGet(0, i32);
+ Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
+ function->body = store;
+ addExportedFunction(wasm, function);
+}
+
+Name EmscriptenGlueGenerator::importStackOverflowHandler() {
+ ImportInfo info(wasm);
+
+ if (auto* existing = info.getImportedFunction(ENV, STACK_OVERFLOW_IMPORT)) {
+ return existing->name;
+ } else {
+ auto* import = new Function;
+ import->name = STACK_OVERFLOW_IMPORT;
+ import->module = ENV;
+ import->base = STACK_OVERFLOW_IMPORT;
+ auto* functionType = ensureFunctionType("v", &wasm);
+ import->type = functionType->name;
+ FunctionTypeUtils::fillFunction(import, functionType);
+ wasm.addFunction(import);
+ return STACK_OVERFLOW_IMPORT;
+ }
+}
+
const Address UNKNOWN_OFFSET(uint32_t(-1));
std::vector<Address> getSegmentOffsets(Module& wasm) {