/* * Copyright 2020 WebAssembly Community Group participants * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // // Enforce stack pointer limits. This pass will add checks around all // assignments to the __stack_pointer global that LLVM uses for its // shadow stack. // #include "abi/js.h" #include "ir/import-utils.h" #include "pass.h" #include "shared-constants.h" #include "support/debug.h" #include "wasm-emscripten.h" #define DEBUG_TYPE "stack-check" namespace wasm { static Name STACK_LIMIT("__stack_limit"); static Name SET_STACK_LIMIT("__set_stack_limit"); static void importStackOverflowHandler(Module& module, Name name) { ImportInfo info(module); if (!info.getImportedFunction(ENV, name)) { auto* import = new Function; import->name = name; import->module = ENV; import->base = name; import->sig = Signature(Type::none, Type::none); module.addFunction(import); } } static void addExportedFunction(Module& module, Function* function) { module.addFunction(function); auto export_ = new Export; export_->name = export_->value = function->name; export_->kind = ExternalKind::Function; module.addExport(export_); } static void generateSetStackLimitFunction(Module& module) { Builder builder(module); Function* function = builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {}); LocalGet* getArg = builder.makeLocalGet(0, Type::i32); Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg); function->body = store; addExportedFunction(module, function); } struct EnforceStackLimit : public WalkerPass> { EnforceStackLimit(Global* stackPointer, Global* stackLimit, Builder& builder, Name handler) : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder), handler(handler) {} bool isFunctionParallel() override { return true; } Pass* create() override { return new EnforceStackLimit(stackPointer, stackLimit, builder, handler); } Expression* stackBoundsCheck(Function* func, Expression* value, Global* stackPointer, Global* stackLimit) { // Add a local to store the value of the expression. We need the value // twice: once to check if it has overflowed, and again to assign to store // it. auto newSP = Builder::addVar(func, stackPointer->type); // If we imported a handler, call it. That can show a nice error in JS. // Otherwise, just trap. Expression* handlerExpr; if (handler.is()) { handlerExpr = builder.makeCall(handler, {}, Type::none); } else { handlerExpr = builder.makeUnreachable(); } // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit)) auto check = builder.makeIf( builder.makeBinary( BinaryOp::LtUInt32, builder.makeLocalTee(newSP, value, stackPointer->type), builder.makeGlobalGet(stackLimit->name, stackLimit->type)), handlerExpr); // (global.set $__stack_pointer (local.get $newSP)) auto newSet = builder.makeGlobalSet( stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type)); return builder.blockify(check, newSet); } void visitGlobalSet(GlobalSet* curr) { if (getModule()->getGlobalOrNull(curr->name) == stackPointer) { replaceCurrent( stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit)); } } private: Global* stackPointer; Global* stackLimit; Builder& builder; Name handler; }; struct StackCheck : public Pass { void run(PassRunner* runner, Module* module) override { Global* stackPointer = getStackPointerGlobal(*module); if (!stackPointer) { BYN_DEBUG(std::cerr << "no stack pointer found\n"); return; } Name handler; auto handlerName = runner->options.getArgumentOrDefault("stack-check-handler", ""); if (handlerName != "") { handler = handlerName; importStackOverflowHandler(*module, handler); } Builder builder(*module); Global* stackLimit = builder.makeGlobal(STACK_LIMIT, stackPointer->type, builder.makeConst(int32_t(0)), Builder::Mutable); module->addGlobal(stackLimit); PassRunner innerRunner(module); EnforceStackLimit(stackPointer, stackLimit, builder, handler) .run(&innerRunner, module); generateSetStackLimitFunction(*module); } }; Pass* createStackCheckPass() { return new StackCheck; } } // namespace wasm