summaryrefslogtreecommitdiff
path: root/src/wasm/wasm-emscripten.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm/wasm-emscripten.cpp')
-rw-r--r--src/wasm/wasm-emscripten.cpp121
1 files changed, 118 insertions, 3 deletions
diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp
index 3ee3e4424..c21016baa 100644
--- a/src/wasm/wasm-emscripten.cpp
+++ b/src/wasm/wasm-emscripten.cpp
@@ -37,8 +37,11 @@ static Name STACK_SAVE("stackSave");
static Name STACK_RESTORE("stackRestore");
static Name STACK_ALLOC("stackAlloc");
static Name STACK_INIT("stack$init");
+static Name STACK_LIMIT("__stack_limit");
+static Name SET_STACK_LIMIT("__set_stack_limit");
static Name POST_INSTANTIATE("__post_instantiate");
static Name ASSIGN_GOT_ENTIRES("__assign_got_enties");
+static Name STACK_OVERFLOW_IMPORT("__handle_stack_overflow");
void addExportedFunction(Module& wasm, Function* function) {
wasm.addFunction(function);
@@ -92,8 +95,32 @@ Expression* EmscriptenGlueGenerator::generateLoadStackPointer() {
return builder.makeGlobalGet(stackPointer->name, i32);
}
+inline Expression* stackBoundsCheck(Builder& builder,
+ Function* func,
+ Expression* value,
+ Global* stackPointer,
+ Global* stackLimit,
+ Name handler) {
+ // Add a local to store the value of the expression. We need the value twice:
+ // once to check if it has overflowed, and again to assign to store it.
+ auto newSP = Builder::addVar(func, stackPointer->type);
+ // (if (i32.lt_u (local.tee $newSP (...value...)) (global.get $__stack_limit))
+ // (call $handler))
+ auto check =
+ builder.makeIf(builder.makeBinary(
+ BinaryOp::LtUInt32,
+ builder.makeLocalTee(newSP, value),
+ builder.makeGlobalGet(stackLimit->name, stackLimit->type)),
+ builder.makeCall(handler, {}, none));
+ // (global.set $__stack_pointer (local.get $newSP))
+ auto newSet = builder.makeGlobalSet(
+ stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
+ return builder.blockify(check, newSet);
+}
+
Expression*
-EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
+EmscriptenGlueGenerator::generateStoreStackPointer(Function* func,
+ Expression* value) {
if (!useStackPointerGlobal) {
return builder.makeStore(
/* bytes =*/4,
@@ -107,6 +134,14 @@ EmscriptenGlueGenerator::generateStoreStackPointer(Expression* value) {
if (!stackPointer) {
Fatal() << "stack pointer global not found";
}
+ if (auto* stackLimit = wasm.getGlobalOrNull(STACK_LIMIT)) {
+ return stackBoundsCheck(builder,
+ func,
+ value,
+ stackPointer,
+ stackLimit,
+ importStackOverflowHandler());
+ }
return builder.makeGlobalSet(stackPointer->name, value);
}
@@ -132,7 +167,7 @@ void EmscriptenGlueGenerator::generateStackAllocFunction() {
Const* subConst = builder.makeConst(Literal(~bitMask));
Binary* maskedSub = builder.makeBinary(AndInt32, sub, subConst);
LocalSet* teeStackLocal = builder.makeLocalTee(1, maskedSub);
- Expression* storeStack = generateStoreStackPointer(teeStackLocal);
+ Expression* storeStack = generateStoreStackPointer(function, teeStackLocal);
Block* block = builder.makeBlock();
block->list.push_back(storeStack);
@@ -149,7 +184,7 @@ void EmscriptenGlueGenerator::generateStackRestoreFunction() {
Function* function =
builder.makeFunction(STACK_RESTORE, std::move(params), none, {});
LocalGet* getArg = builder.makeLocalGet(0, i32);
- Expression* store = generateStoreStackPointer(getArg);
+ Expression* store = generateStoreStackPointer(function, getArg);
function->body = store;
@@ -444,6 +479,86 @@ void EmscriptenGlueGenerator::replaceStackPointerGlobal() {
wasm.removeGlobal(stackPointer->name);
}
+struct StackLimitEnforcer : public WalkerPass<PostWalker<StackLimitEnforcer>> {
+ StackLimitEnforcer(Global* stackPointer,
+ Global* stackLimit,
+ Builder& builder,
+ Name handler)
+ : stackPointer(stackPointer), stackLimit(stackLimit), builder(builder),
+ handler(handler) {}
+
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override {
+ return new StackLimitEnforcer(stackPointer, stackLimit, builder, handler);
+ }
+
+ void visitGlobalSet(GlobalSet* curr) {
+ if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
+ replaceCurrent(stackBoundsCheck(builder,
+ getFunction(),
+ curr->value,
+ stackPointer,
+ stackLimit,
+ handler));
+ }
+ }
+
+private:
+ Global* stackPointer;
+ Global* stackLimit;
+ Builder& builder;
+ Name handler;
+};
+
+void EmscriptenGlueGenerator::enforceStackLimit() {
+ Global* stackPointer = getStackPointerGlobal();
+ if (!stackPointer) {
+ return;
+ }
+
+ auto* stackLimit = builder.makeGlobal(STACK_LIMIT,
+ stackPointer->type,
+ builder.makeConst(Literal(0)),
+ Builder::Mutable);
+ wasm.addGlobal(stackLimit);
+
+ auto handler = importStackOverflowHandler();
+
+ StackLimitEnforcer walker(stackPointer, stackLimit, builder, handler);
+ PassRunner runner(&wasm);
+ walker.run(&runner, &wasm);
+
+ generateSetStackLimitFunction();
+}
+
+void EmscriptenGlueGenerator::generateSetStackLimitFunction() {
+ Function* function =
+ builder.makeFunction(SET_STACK_LIMIT, std::vector<Type>({i32}), none, {});
+ LocalGet* getArg = builder.makeLocalGet(0, i32);
+ Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg);
+ function->body = store;
+ addExportedFunction(wasm, function);
+}
+
+Name EmscriptenGlueGenerator::importStackOverflowHandler() {
+ ImportInfo info(wasm);
+
+ if (auto* existing = info.getImportedFunction(ENV, STACK_OVERFLOW_IMPORT)) {
+ return existing->name;
+ } else {
+ auto* import = new Function;
+ import->name = STACK_OVERFLOW_IMPORT;
+ import->module = ENV;
+ import->base = STACK_OVERFLOW_IMPORT;
+ auto* functionType = ensureFunctionType("v", &wasm);
+ import->type = functionType->name;
+ FunctionTypeUtils::fillFunction(import, functionType);
+ wasm.addFunction(import);
+ return STACK_OVERFLOW_IMPORT;
+ }
+}
+
const Address UNKNOWN_OFFSET(uint32_t(-1));
std::vector<Address> getSegmentOffsets(Module& wasm) {