summaryrefslogtreecommitdiff
path: root/src/passes/StackCheck.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/StackCheck.cpp')
-rw-r--r--src/passes/StackCheck.cpp77
1 files changed, 53 insertions, 24 deletions
diff --git a/src/passes/StackCheck.cpp b/src/passes/StackCheck.cpp
index 4c797351b..183420fba 100644
--- a/src/passes/StackCheck.cpp
+++ b/src/passes/StackCheck.cpp
@@ -31,8 +31,16 @@
namespace wasm {
+// The base is where the stack begins. As it goes down, that is the highest
+// valid address.
+static Name STACK_BASE("__stack_base");
+// The limit is the farthest it can grow to, which is the lowest valid address.
static Name STACK_LIMIT("__stack_limit");
+// Old version, which sets the limit.
+// TODO: remove this
static Name SET_STACK_LIMIT("__set_stack_limit");
+// New version, which sets the base and the limit.
+static Name SET_STACK_LIMITS("__set_stack_limits");
static void importStackOverflowHandler(Module& module, Name name) {
ImportInfo info(module);
@@ -55,34 +63,43 @@ static void addExportedFunction(Module& module, Function* function) {
module.addExport(export_);
}
-static void generateSetStackLimitFunction(Module& module) {
+static void generateSetStackLimitFunctions(Module& module) {
Builder builder(module);
- Function* function =
+ // One-parameter version
+ Function* limitFunc =
builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {});
LocalGet* getArg = builder.makeLocalGet(0, Type::i32);
Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
- function->body = store;
- addExportedFunction(module, function);
+ limitFunc->body = store;
+ addExportedFunction(module, limitFunc);
+ // Two-parameter version
+ Function* limitsFunc = builder.makeFunction(
+ SET_STACK_LIMITS, Signature({Type::i32, Type::i32}, Type::none), {});
+ LocalGet* getBase = builder.makeLocalGet(0, Type::i32);
+ Expression* storeBase = builder.makeGlobalSet(STACK_BASE, getBase);
+ LocalGet* getLimit = builder.makeLocalGet(1, Type::i32);
+ Expression* storeLimit = builder.makeGlobalSet(STACK_LIMIT, getLimit);
+ limitsFunc->body = builder.makeBlock({storeBase, storeLimit});
+ addExportedFunction(module, limitsFunc);
}
-struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
- EnforceStackLimit(Global* stackPointer,
- Global* stackLimit,
- Builder& builder,
- Name handler)
- : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
- handler(handler) {}
+struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> {
+ EnforceStackLimits(Global* stackPointer,
+ Global* stackBase,
+ Global* stackLimit,
+ Builder& builder,
+ Name handler)
+ : stackPointer(stackPointer), stackBase(stackBase), stackLimit(stackLimit),
+ builder(builder), handler(handler) {}
bool isFunctionParallel() override { return true; }
Pass* create() override {
- return new EnforceStackLimit(stackPointer, stackLimit, builder, handler);
+ return new EnforceStackLimits(
+ stackPointer, stackBase, stackLimit, builder, handler);
}
- Expression* stackBoundsCheck(Function* func,
- Expression* value,
- Global* stackPointer,
- Global* stackLimit) {
+ Expression* stackBoundsCheck(Function* func, Expression* value) {
// Add a local to store the value of the expression. We need the value
// twice: once to check if it has overflowed, and again to assign to store
// it.
@@ -95,12 +112,18 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
} else {
handlerExpr = builder.makeUnreachable();
}
- // (if (i32.lt_u (local.tee $newSP (...val...)) (global.get $__stack_limit))
+ // If it is >= the base or <= the limit, then error.
auto check = builder.makeIf(
builder.makeBinary(
- BinaryOp::LtUInt32,
- builder.makeLocalTee(newSP, value, stackPointer->type),
- builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
+ BinaryOp::OrInt32,
+ builder.makeBinary(
+ BinaryOp::GtUInt32,
+ builder.makeLocalTee(newSP, value, stackPointer->type),
+ builder.makeGlobalGet(stackBase->name, stackBase->type)),
+ builder.makeBinary(
+ BinaryOp::LtUInt32,
+ builder.makeLocalGet(newSP, stackPointer->type),
+ builder.makeGlobalGet(stackLimit->name, stackLimit->type))),
handlerExpr);
// (global.set $__stack_pointer (local.get $newSP))
auto newSet = builder.makeGlobalSet(
@@ -110,13 +133,13 @@ struct EnforceStackLimit : public WalkerPass<PostWalker<EnforceStackLimit>> {
void visitGlobalSet(GlobalSet* curr) {
if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
- replaceCurrent(
- stackBoundsCheck(getFunction(), curr->value, stackPointer, stackLimit));
+ replaceCurrent(stackBoundsCheck(getFunction(), curr->value));
}
}
private:
Global* stackPointer;
+ Global* stackBase;
Global* stackLimit;
Builder& builder;
Name handler;
@@ -139,6 +162,12 @@ struct StackCheck : public Pass {
}
Builder builder(*module);
+ Global* stackBase = builder.makeGlobal(STACK_BASE,
+ stackPointer->type,
+ builder.makeConst(int32_t(0)),
+ Builder::Mutable);
+ module->addGlobal(stackBase);
+
Global* stackLimit = builder.makeGlobal(STACK_LIMIT,
stackPointer->type,
builder.makeConst(int32_t(0)),
@@ -146,9 +175,9 @@ struct StackCheck : public Pass {
module->addGlobal(stackLimit);
PassRunner innerRunner(module);
- EnforceStackLimit(stackPointer, stackLimit, builder, handler)
+ EnforceStackLimits(stackPointer, stackBase, stackLimit, builder, handler)
.run(&innerRunner, module);
- generateSetStackLimitFunction(*module);
+ generateSetStackLimitFunctions(*module);
}
};