diff options
Diffstat (limited to 'src/passes')
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/StackCheck.cpp | 157 | ||||
-rw-r--r-- | src/passes/pass.cpp | 3 | ||||
-rw-r--r-- | src/passes/passes.h | 1 |
4 files changed, 162 insertions, 0 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index e54c5ca18..8a51e3ac8 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -75,6 +75,7 @@ set(passes_SOURCES SimplifyLocals.cpp Souperify.cpp SpillPointers.cpp + StackCheck.cpp SSAify.cpp Untee.cpp Vacuum.cpp diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp new file mode 100644 index 000000000..4c797351b --- /dev/null +++ b/src/passes/StackCheck.cpp @@ -0,0 +1,157 @@ +/* + * Copyright 2020 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Enforce stack pointer limits. This pass will add checks around all +// assignments to the __stack_pointer global that LLVM uses for its +// shadow stack. +// + +#include "abi/js.h" +#include "ir/import-utils.h" +#include "pass.h" +#include "shared-constants.h" +#include "support/debug.h" +#include "wasm-emscripten.h" + +#define DEBUG_TYPE "stack-check" + +namespace wasm { + +static Name STACK_LIMIT("__stack_limit"); +static Name SET_STACK_LIMIT("__set_stack_limit"); + +static void importStackOverflowHandler(Module& module, Name name) { + ImportInfo info(module); + + if (!info.getImportedFunction(ENV, name)) { + auto* import = new Function; + import->name = name; + import->module = ENV; + import->base = name; + import->sig = Signature(Type::none, Type::none); + module.addFunction(import); + } +} + +static void addExportedFunction(Module& module, Function* function) { + module.addFunction(function); + auto export_ = new Export; + export_->name = export_->value = function->name; + export_->kind = ExternalKind::Function; + module.addExport(export_); +} + +static void generateSetStackLimitFunction(Module& module) { + Builder builder(module); + Function* function = + builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {}); + LocalGet* getArg = builder.makeLocalGet(0, Type::i32); + Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg); + function->body = store; + addExportedFunction(module, function); +} + +struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> { + EnforceStackLimit(Global* stackPointer, + Global* stackLimit, + Builder& builder, + Name handler) + : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder), + handler(handler) {} + + bool isFunctionParallel() override { return true; } + + Pass* create() override { + return new EnforceStackLimit(stackPointer, stackLimit, builder, handler); + } + + Expression* stackBoundsCheck(Function* func, + Expression* value, + Global* stackPointer, + Global* stackLimit) { + // Add a local to store the value of the expression. We need the value + // twice: once to check if it has overflowed, and again to assign to store + // it. + auto newSP = Builder::addVar(func, stackPointer->type); + // If we imported a handler, call it. That can show a nice error in JS. + // Otherwise, just trap. + Expression* handlerExpr; + if (handler.is()) { + handlerExpr = builder.makeCall(handler, {}, Type::none); + } else { + handlerExpr = builder.makeUnreachable(); + } + // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit)) + auto check = builder.makeIf( + builder.makeBinary( + BinaryOp::LtUInt32, + builder.makeLocalTee(newSP, value, stackPointer->type), + builder.makeGlobalGet(stackLimit->name, stackLimit->type)), + handlerExpr); + // (global.set $__stack_pointer (local.get $newSP)) + auto newSet = builder.makeGlobalSet( + stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type)); + return builder.blockify(check, newSet); + } + + void visitGlobalSet(GlobalSet* curr) { + if (getModule()->getGlobalOrNull(curr->name) == stackPointer) { + replaceCurrent( + stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit)); + } + } + +private: + Global* stackPointer; + Global* stackLimit; + Builder& builder; + Name handler; +}; + +struct StackCheck : public Pass { + void run(PassRunner* runner, Module* module) override { + Global* stackPointer = getStackPointerGlobal(*module); + if (!stackPointer) { + BYN_DEBUG(std::cerr << "no stack pointer found\n"); + return; + } + + Name handler; + auto handlerName = + runner->options.getArgumentOrDefault("stack-check-handler", ""); + if (handlerName != "") { + handler = handlerName; + importStackOverflowHandler(*module, handler); + } + + Builder builder(*module); + Global* stackLimit = builder.makeGlobal(STACK_LIMIT, + stackPointer->type, + builder.makeConst(int32_t(0)), + Builder::Mutable); + module->addGlobal(stackLimit); + + PassRunner innerRunner(module); + EnforceStackLimit(stackPointer, stackLimit, builder, handler) + .run(&innerRunner, module); + generateSetStackLimitFunction(*module); + } +}; + +Pass* createStackCheckPass() { return new StackCheck; } + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 9ff06b153..14bf5d185 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -343,6 +343,9 @@ void PassRegistry::registerPasses() { createSSAifyNoMergePass); registerPass( "strip", "deprecated; same as strip-debug", createStripDebugPass); + registerPass("stack-check", + "enforce limits on llvm's __stack_pointer global", + createStackCheckPass); registerPass("strip-debug", "strip debug info (including the names section)", createStripDebugPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index ff94ef51c..ae4fa1e87 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -109,6 +109,7 @@ Pass* createSimplifyLocalsNoNestingPass(); Pass* createSimplifyLocalsNoTeePass(); Pass* createSimplifyLocalsNoStructurePass(); Pass* createSimplifyLocalsNoTeeNoStructurePass(); +Pass* createStackCheckPass(); Pass* createStripDebugPass(); Pass* createStripDWARFPass(); Pass* createStripProducersPass(); |