diff options
Diffstat (limited to 'src/passes/StackCheck.cpp')
-rw-r--r-- | src/passes/StackCheck.cpp | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp new file mode 100644 index 000000000..4c797351b --- /dev/null +++ b/src/passes/StackCheck.cpp @@ -0,0 +1,157 @@ +/* + * Copyright 2020 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Enforce stack pointer limits. This pass will add checks around all +// assignments to the __stack_pointer global that LLVM uses for its +// shadow stack. +// + +#include "abi/js.h" +#include "ir/import-utils.h" +#include "pass.h" +#include "shared-constants.h" +#include "support/debug.h" +#include "wasm-emscripten.h" + +#define DEBUG_TYPE "stack-check" + +namespace wasm { + +static Name STACK_LIMIT("__stack_limit"); +static Name SET_STACK_LIMIT("__set_stack_limit"); + +static void importStackOverflowHandler(Module& module, Name name) { + ImportInfo info(module); + + if (!info.getImportedFunction(ENV, name)) { + auto* import = new Function; + import->name = name; + import->module = ENV; + import->base = name; + import->sig = Signature(Type::none, Type::none); + module.addFunction(import); + } +} + +static void addExportedFunction(Module& module, Function* function) { + module.addFunction(function); + auto export_ = new Export; + export_->name = export_->value = function->name; + export_->kind = ExternalKind::Function; + module.addExport(export_); +} + +static void generateSetStackLimitFunction(Module& module) { + Builder builder(module); + Function* function = + builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {}); + LocalGet* getArg = builder.makeLocalGet(0, Type::i32); + Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg); + function->body = store; + addExportedFunction(module, function); +} + +struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> { + EnforceStackLimit(Global* stackPointer, + Global* stackLimit, + Builder& builder, + Name handler) + : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder), + handler(handler) {} + + bool isFunctionParallel() override { return true; } + + Pass* create() override { + return new EnforceStackLimit(stackPointer, stackLimit, builder, handler); + } + + Expression* stackBoundsCheck(Function* func, + Expression* value, + Global* stackPointer, + Global* stackLimit) { + // Add a local to store the value of the expression. We need the value + // twice: once to check if it has overflowed, and again to assign to store + // it. + auto newSP = Builder::addVar(func, stackPointer->type); + // If we imported a handler, call it. That can show a nice error in JS. + // Otherwise, just trap. + Expression* handlerExpr; + if (handler.is()) { + handlerExpr = builder.makeCall(handler, {}, Type::none); + } else { + handlerExpr = builder.makeUnreachable(); + } + // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit)) + auto check = builder.makeIf( + builder.makeBinary( + BinaryOp::LtUInt32, + builder.makeLocalTee(newSP, value, stackPointer->type), + builder.makeGlobalGet(stackLimit->name, stackLimit->type)), + handlerExpr); + // (global.set $__stack_pointer (local.get $newSP)) + auto newSet = builder.makeGlobalSet( + stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type)); + return builder.blockify(check, newSet); + } + + void visitGlobalSet(GlobalSet* curr) { + if (getModule()->getGlobalOrNull(curr->name) == stackPointer) { + replaceCurrent( + stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit)); + } + } + +private: + Global* stackPointer; + Global* stackLimit; + Builder& builder; + Name handler; +}; + +struct StackCheck : public Pass { + void run(PassRunner* runner, Module* module) override { + Global* stackPointer = getStackPointerGlobal(*module); + if (!stackPointer) { + BYN_DEBUG(std::cerr << "no stack pointer found\n"); + return; + } + + Name handler; + auto handlerName = + runner->options.getArgumentOrDefault("stack-check-handler", ""); + if (handlerName != "") { + handler = handlerName; + importStackOverflowHandler(*module, handler); + } + + Builder builder(*module); + Global* stackLimit = builder.makeGlobal(STACK_LIMIT, + stackPointer->type, + builder.makeConst(int32_t(0)), + Builder::Mutable); + module->addGlobal(stackLimit); + + PassRunner innerRunner(module); + EnforceStackLimit(stackPointer, stackLimit, builder, handler) + .run(&innerRunner, module); + generateSetStackLimitFunction(*module); + } +}; + +Pass* createStackCheckPass() { return new StackCheck; } + +} // namespace wasm |