summaryrefslogtreecommitdiff
path: root/src/passes/StackCheck.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/StackCheck.cpp')
-rw-r--r--src/passes/StackCheck.cpp157
1 files changed, 157 insertions, 0 deletions
diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp
new file mode 100644
index 000000000..4c797351b
--- /dev/null
+++ b/src/passes/StackCheck.cpp
@@ -0,0 +1,157 @@
+/*
+ * Copyright 2020 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Enforce stack pointer limits. This pass will add checks around all
+// assignments to the __stack_pointer global that LLVM uses for its
+// shadow stack.
+//
+
+#include "abi/js.h"
+#include "ir/import-utils.h"
+#include "pass.h"
+#include "shared-constants.h"
+#include "support/debug.h"
+#include "wasm-emscripten.h"
+
+#define DEBUG_TYPE "stack-check"
+
+namespace wasm {
+
+static Name STACK_LIMIT("__stack_limit");
+static Name SET_STACK_LIMIT("__set_stack_limit");
+
+static void importStackOverflowHandler(Module& module, Name name) {
+ ImportInfo info(module);
+
+ if (!info.getImportedFunction(ENV, name)) {
+ auto* import = new Function;
+ import->name = name;
+ import->module = ENV;
+ import->base = name;
+ import->sig = Signature(Type::none, Type::none);
+ module.addFunction(import);
+ }
+}
+
+static void addExportedFunction(Module& module, Function* function) {
+ module.addFunction(function);
+ auto export_ = new Export;
+ export_->name = export_->value = function->name;
+ export_->kind = ExternalKind::Function;
+ module.addExport(export_);
+}
+
+static void generateSetStackLimitFunction(Module& module) {
+ Builder builder(module);
+ Function* function =
+ builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {});
+ LocalGet* getArg = builder.makeLocalGet(0, Type::i32);
+ Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
+ function->body = store;
+ addExportedFunction(module, function);
+}
+
+struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
+ EnforceStackLimit(Global* stackPointer,
+ Global* stackLimit,
+ Builder& builder,
+ Name handler)
+ : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
+ handler(handler) {}
+
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override {
+ return new EnforceStackLimit(stackPointer, stackLimit, builder, handler);
+ }
+
+ Expression* stackBoundsCheck(Function* func,
+ Expression* value,
+ Global* stackPointer,
+ Global* stackLimit) {
+ // Add a local to store the value of the expression. We need the value
+ // twice: once to check if it has overflowed, and again to assign to store
+ // it.
+ auto newSP = Builder::addVar(func, stackPointer->type);
+ // If we imported a handler, call it. That can show a nice error in JS.
+ // Otherwise, just trap.
+ Expression* handlerExpr;
+ if (handler.is()) {
+ handlerExpr = builder.makeCall(handler, {}, Type::none);
+ } else {
+ handlerExpr = builder.makeUnreachable();
+ }
+ // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit))
+ auto check = builder.makeIf(
+ builder.makeBinary(
+ BinaryOp::LtUInt32,
+ builder.makeLocalTee(newSP, value, stackPointer->type),
+ builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
+ handlerExpr);
+ // (global.set $__stack_pointer (local.get $newSP))
+ auto newSet = builder.makeGlobalSet(
+ stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
+ return builder.blockify(check, newSet);
+ }
+
+ void visitGlobalSet(GlobalSet* curr) {
+ if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
+ replaceCurrent(
+ stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit));
+ }
+ }
+
+private:
+ Global* stackPointer;
+ Global* stackLimit;
+ Builder& builder;
+ Name handler;
+};
+
+struct StackCheck : public Pass {
+ void run(PassRunner* runner, Module* module) override {
+ Global* stackPointer = getStackPointerGlobal(*module);
+ if (!stackPointer) {
+ BYN_DEBUG(std::cerr << "no stack pointer found\n");
+ return;
+ }
+
+ Name handler;
+ auto handlerName =
+ runner->options.getArgumentOrDefault("stack-check-handler", "");
+ if (handlerName != "") {
+ handler = handlerName;
+ importStackOverflowHandler(*module, handler);
+ }
+
+ Builder builder(*module);
+ Global* stackLimit = builder.makeGlobal(STACK_LIMIT,
+ stackPointer->type,
+ builder.makeConst(int32_t(0)),
+ Builder::Mutable);
+ module->addGlobal(stackLimit);
+
+ PassRunner innerRunner(module);
+ EnforceStackLimit(stackPointer, stackLimit, builder, handler)
+ .run(&innerRunner, module);
+ generateSetStackLimitFunction(*module);
+ }
+};
+
+Pass* createStackCheckPass() { return new StackCheck; }
+
+} // namespace wasm