summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Clegg <sbc@chromium.org>2020-07-27 22:23:17 -0700
committerGitHub <noreply@github.com>2020-07-27 22:23:17 -0700
commit32ab8bac04af52121c6985a9a019c0fdec957f03 (patch)
tree07f9471a68be9cf5c2176e92c69e3a94bbf9df9f
parent85e45a4371ef9e6b143e9675d5f52136ef881c12 (diff)
downloadbinaryen-32ab8bac04af52121c6985a9a019c0fdec957f03.tar.gz
binaryen-32ab8bac04af52121c6985a9a019c0fdec957f03.tar.bz2
binaryen-32ab8bac04af52121c6985a9a019c0fdec957f03.zip
Move stack-check into its own pass (#2994)
This new pass takes an optional stack-check-handler argument which is the name of the function to call on stack overflow. If no argument is passed then it just traps.
-rw-r--r--src/pass.h1
-rw-r--r--src/passes/CMakeLists.txt1
-rw-r--r--src/passes/StackCheck.cpp157
-rw-r--r--src/passes/pass.cpp3
-rw-r--r--src/passes/passes.h1
-rw-r--r--src/tools/wasm-emscripten-finalize.cpp13
-rw-r--r--src/wasm-emscripten.h4
-rw-r--r--src/wasm/wasm-emscripten.cpp111
8 files changed, 172 insertions, 119 deletions
diff --git a/src/pass.h b/src/pass.h
index 760feb7c0..8b8e73fa0 100644
--- a/src/pass.h
+++ b/src/pass.h
@@ -153,7 +153,6 @@ struct PassRunner {
PassRunner(const PassRunner&) = delete;
PassRunner& operator=(const PassRunner&) = delete;
- void setOptions(PassOptions newOptions) { options = newOptions; }
void setDebug(bool debug) {
options.debug = debug;
// validate everything by default if debugging
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index e54c5ca18..8a51e3ac8 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -75,6 +75,7 @@ set(passes_SOURCES
SimplifyLocals.cpp
Souperify.cpp
SpillPointers.cpp
+ StackCheck.cpp
SSAify.cpp
Untee.cpp
Vacuum.cpp
diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp
new file mode 100644
index 000000000..4c797351b
--- /dev/null
+++ b/src/passes/StackCheck.cpp
@@ -0,0 +1,157 @@
+/*
+ * Copyright 2020 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Enforce stack pointer limits. This pass will add checks around all
+// assignments to the __stack_pointer global that LLVM uses for its
+// shadow stack.
+//
+
+#include "abi/js.h"
+#include "ir/import-utils.h"
+#include "pass.h"
+#include "shared-constants.h"
+#include "support/debug.h"
+#include "wasm-emscripten.h"
+
+#define DEBUG_TYPE "stack-check"
+
+namespace wasm {
+
+static Name STACK_LIMIT("__stack_limit");
+static Name SET_STACK_LIMIT("__set_stack_limit");
+
+static void importStackOverflowHandler(Module& module, Name name) {
+ ImportInfo info(module);
+
+ if (!info.getImportedFunction(ENV, name)) {
+ auto* import = new Function;
+ import->name = name;
+ import->module = ENV;
+ import->base = name;
+ import->sig = Signature(Type::none, Type::none);
+ module.addFunction(import);
+ }
+}
+
+static void addExportedFunction(Module& module, Function* function) {
+ module.addFunction(function);
+ auto export_ = new Export;
+ export_->name = export_->value = function->name;
+ export_->kind = ExternalKind::Function;
+ module.addExport(export_);
+}
+
+static void generateSetStackLimitFunction(Module& module) {
+ Builder builder(module);
+ Function* function =
+ builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {});
+ LocalGet* getArg = builder.makeLocalGet(0, Type::i32);
+ Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
+ function->body = store;
+ addExportedFunction(module, function);
+}
+
+struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
+ EnforceStackLimit(Global* stackPointer,
+ Global* stackLimit,
+ Builder& builder,
+ Name handler)
+ : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
+ handler(handler) {}
+
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override {
+ return new EnforceStackLimit(stackPointer, stackLimit, builder, handler);
+ }
+
+ Expression* stackBoundsCheck(Function* func,
+ Expression* value,
+ Global* stackPointer,
+ Global* stackLimit) {
+ // Add a local to store the value of the expression. We need the value
+ // twice: once to check if it has overflowed, and again to assign to store
+ // it.
+ auto newSP = Builder::addVar(func, stackPointer->type);
+ // If we imported a handler, call it. That can show a nice error in JS.
+ // Otherwise, just trap.
+ Expression* handlerExpr;
+ if (handler.is()) {
+ handlerExpr = builder.makeCall(handler, {}, Type::none);
+ } else {
+ handlerExpr = builder.makeUnreachable();
+ }
+ // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit))
+ auto check = builder.makeIf(
+ builder.makeBinary(
+ BinaryOp::LtUInt32,
+ builder.makeLocalTee(newSP, value, stackPointer->type),
+ builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
+ handlerExpr);
+ // (global.set $__stack_pointer (local.get $newSP))
+ auto newSet = builder.makeGlobalSet(
+ stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
+ return builder.blockify(check, newSet);
+ }
+
+ void visitGlobalSet(GlobalSet* curr) {
+ if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
+ replaceCurrent(
+ stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit));
+ }
+ }
+
+private:
+ Global* stackPointer;
+ Global* stackLimit;
+ Builder& builder;
+ Name handler;
+};
+
+struct StackCheck : public Pass {
+ void run(PassRunner* runner, Module* module) override {
+ Global* stackPointer = getStackPointerGlobal(*module);
+ if (!stackPointer) {
+ BYN_DEBUG(std::cerr << "no stack pointer found\n");
+ return;
+ }
+
+ Name handler;
+ auto handlerName =
+ runner->options.getArgumentOrDefault("stack-check-handler", "");
+ if (handlerName != "") {
+ handler = handlerName;
+ importStackOverflowHandler(*module, handler);
+ }
+
+ Builder builder(*module);
+ Global* stackLimit = builder.makeGlobal(STACK_LIMIT,
+ stackPointer->type,
+ builder.makeConst(int32_t(0)),
+ Builder::Mutable);
+ module->addGlobal(stackLimit);
+
+ PassRunner innerRunner(module);
+ EnforceStackLimit(stackPointer, stackLimit, builder, handler)
+ .run(&innerRunner, module);
+ generateSetStackLimitFunction(*module);
+ }
+};
+
+Pass* createStackCheckPass() { return new StackCheck; }
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 9ff06b153..14bf5d185 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -343,6 +343,9 @@ void PassRegistry::registerPasses() {
createSSAifyNoMergePass);
registerPass(
"strip", "deprecated; same as strip-debug", createStripDebugPass);
+ registerPass("stack-check",
+ "enforce limits on llvm's __stack_pointer global",
+ createStackCheckPass);
registerPass("strip-debug",
"strip debug info (including the names section)",
createStripDebugPass);
diff --git a/src/passes/passes.h b/src/passes/passes.h
index ff94ef51c..ae4fa1e87 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -109,6 +109,7 @@ Pass* createSimplifyLocalsNoNestingPass();
Pass* createSimplifyLocalsNoTeePass();
Pass* createSimplifyLocalsNoStructurePass();
Pass* createSimplifyLocalsNoTeeNoStructurePass();
+Pass* createStackCheckPass();
Pass* createStripDebugPass();
Pass* createStripDWARFPass();
Pass* createStripProducersPass();
diff --git a/src/tools/wasm-emscripten-finalize.cpp b/src/tools/wasm-emscripten-finalize.cpp
index 6e0831a08..cc57c9311 100644
--- a/src/tools/wasm-emscripten-finalize.cpp
+++ b/src/tools/wasm-emscripten-finalize.cpp
@@ -242,7 +242,15 @@ int main(int argc, const char* argv[]) {
wasm.updateMaps();
if (checkStackOverflow && !sideModule) {
- generator.enforceStackLimit();
+ PassOptions options;
+ if (!standaloneWasm) {
+ // In standalone mode we don't set a handler at all.. which means
+ // just trap on overflow.
+ options.arguments["stack-check-handler"] = "__handle_stack_overflow";
+ }
+ PassRunner passRunner(&wasm, options);
+ passRunner.add("stack-check");
+ passRunner.run();
}
if (sideModule) {
@@ -288,8 +296,7 @@ int main(int argc, const char* argv[]) {
// Legalize the wasm, if BigInts don't make that moot.
if (!bigInt) {
BYN_TRACE("legalizing types\n");
- PassRunner passRunner(&wasm);
- passRunner.setOptions(options.passOptions);
+ PassRunner passRunner(&wasm, options.passOptions);
passRunner.setDebug(options.debug);
passRunner.setDebugInfo(debugInfo);
passRunner.add(ABI::getLegalizationPass(
diff --git a/src/wasm-emscripten.h b/src/wasm-emscripten.h
index ed02b6409..ba956ee39 100644
--- a/src/wasm-emscripten.h
+++ b/src/wasm-emscripten.h
@@ -54,8 +54,6 @@ public:
void fixInvokeFunctionNames();
- void enforceStackLimit();
-
// clang uses name mangling to rename the argc/argv form of main to
// __main_argc_argv. Emscripten in non-standalone mode expects that function
// to be exported as main. This function renames __main_argc_argv to main
@@ -80,8 +78,6 @@ private:
std::unordered_set<Signature> sigs;
void generateDynCallThunk(Signature sig);
- void generateSetStackLimitFunction();
- Name importStackOverflowHandler();
};
} // namespace wasm
diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp
index 0a9f6108b..f37bfe907 100644
--- a/src/wasm/wasm-emscripten.cpp
+++ b/src/wasm/wasm-emscripten.cpp
@@ -38,10 +38,7 @@ cashew::IString EM_ASM_PREFIX("emscripten_asm_const");
cashew::IString EM_JS_PREFIX("__em_js__");
static Name STACK_INIT("stack$init");
-static Name STACK_LIMIT("__stack_limit");
-static Name SET_STACK_LIMIT("__set_stack_limit");
static Name POST_INSTANTIATE("__post_instantiate");
-static Name STACK_OVERFLOW_IMPORT("__handle_stack_overflow");
void addExportedFunction(Module& wasm, Function* function) {
wasm.addFunction(function);
@@ -211,114 +208,6 @@ void EmscriptenGlueGenerator::internalizeStackPointerGlobal() {
wasm.addGlobal(sp);
}
-struct StackLimitEnforcer : public WalkerPass<PostWalker<StackLimitEnforcer>> {
- StackLimitEnforcer(Global* stackPointer,
- Global* stackLimit,
- Builder& builder,
- Name handler)
- : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
- handler(handler) {}
-
- bool isFunctionParallel() override { return true; }
-
- Pass* create() override {
- return new StackLimitEnforcer(stackPointer, stackLimit, builder, handler);
- }
-
- Expression* stackBoundsCheck(Function* func,
- Expression* value,
- Global* stackPointer,
- Global* stackLimit) {
- // Add a local to store the value of the expression. We need the value
- // twice: once to check if it has overflowed, and again to assign to store
- // it.
- auto newSP = Builder::addVar(func, stackPointer->type);
- // If we imported a handler, call it. That can show a nice error in JS.
- // Otherwise, just trap.
- Expression* handlerExpr;
- if (handler.is()) {
- handlerExpr = builder.makeCall(handler, {}, Type::none);
- } else {
- handlerExpr = builder.makeUnreachable();
- }
- // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit))
- auto check = builder.makeIf(
- builder.makeBinary(
- BinaryOp::LtUInt32,
- builder.makeLocalTee(newSP, value, stackPointer->type),
- builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
- handlerExpr);
- // (global.set $__stack_pointer (local.get $newSP))
- auto newSet = builder.makeGlobalSet(
- stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
- return builder.blockify(check, newSet);
- }
-
- void visitGlobalSet(GlobalSet* curr) {
- if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
- replaceCurrent(
- stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit));
- }
- }
-
-private:
- Global* stackPointer;
- Global* stackLimit;
- Builder& builder;
- Name handler;
-};
-
-void EmscriptenGlueGenerator::enforceStackLimit() {
- Global* stackPointer = getStackPointerGlobal(wasm);
- if (!stackPointer) {
- return;
- }
-
- auto* stackLimit = builder.makeGlobal(STACK_LIMIT,
- stackPointer->type,
- builder.makeConst(int32_t(0)),
- Builder::Mutable);
- wasm.addGlobal(stackLimit);
-
- Name handler = importStackOverflowHandler();
- StackLimitEnforcer walker(stackPointer, stackLimit, builder, handler);
- PassRunner runner(&wasm);
- walker.run(&runner, &wasm);
-
- generateSetStackLimitFunction();
-}
-
-void EmscriptenGlueGenerator::generateSetStackLimitFunction() {
- Function* function =
- builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {});
- LocalGet* getArg = builder.makeLocalGet(0, Type::i32);
- Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
- function->body = store;
- addExportedFunction(wasm, function);
-}
-
-Name EmscriptenGlueGenerator::importStackOverflowHandler() {
- // We can call an import to handle stack overflows normally, but not in
- // standalone mode, where we can't import from JS.
- if (standalone) {
- return Name();
- }
-
- ImportInfo info(wasm);
-
- if (auto* existing = info.getImportedFunction(ENV, STACK_OVERFLOW_IMPORT)) {
- return existing->name;
- } else {
- auto* import = new Function;
- import->name = STACK_OVERFLOW_IMPORT;
- import->module = ENV;
- import->base = STACK_OVERFLOW_IMPORT;
- import->sig = Signature(Type::none, Type::none);
- wasm.addFunction(import);
- return STACK_OVERFLOW_IMPORT;
- }
-}
-
const Address UNKNOWN_OFFSET(uint32_t(-1));
std::vector<Address> getSegmentOffsets(Module& wasm) {