diff options
-rw-r--r-- | src/passes/Directize.cpp | 30 | ||||
-rw-r--r-- | src/passes/pass.cpp | 4 | ||||
-rw-r--r-- | test/passes/directize.wast | 182 |
3 files changed, 200 insertions, 16 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index fe6d73b0d..d9400cce7 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -26,8 +26,6 @@ #include "pass.h" #include "wasm-builder.h" #include "wasm-traversal.h" -#include "find_all.h" -#include "ir/module-utils.h" #include "asm_v_wasm.h" namespace wasm { @@ -41,8 +39,9 @@ struct FlatTable { FlatTable(Table& table) { valid = true; for (auto& segment : table.segments) { - auto offset = segment->offset; + auto offset = segment.offset; if (!offset->is<Const>()) { + // TODO: handle some non-constant segments valid = false; return; } @@ -58,12 +57,12 @@ struct FlatTable { } }; -struct FunctionDirectizer : public WalkerPass<FunctionDirectizer> { +struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { bool isFunctionParallel() override { return true; } - Pass* create() override { return new Scanner(flatTable); } + Pass* create() override { return new FunctionDirectizer(flatTable); } - Scanner(FlatTable* flatTable) : flatTable(flatTable) {} + FunctionDirectizer(FlatTable* flatTable) : flatTable(flatTable) {} void visitCallIndirect(CallIndirect* curr) { if (auto* c = curr->target->dynCast<Const>()) { @@ -72,11 +71,15 @@ struct FunctionDirectizer : public WalkerPass<FunctionDirectizer> { // emit an unreachable here, since in Binaryen it is ok to // reorder/replace traps when optimizing (but never to // remove them, at least not by default). - if (index >= flatTable.names.size()) { + if (index >= flatTable->names.size()) { + replaceWithUnreachable(); + return; + } + auto name = flatTable->names[index]; + if (!name.is()) { replaceWithUnreachable(); return; } - auto name = flatTable.names[index]; auto* func = getModule()->getFunction(name); if (getSig(getModule()->getFunctionType(curr->fullType)) != getSig(func)) { @@ -96,16 +99,16 @@ private: FlatTable* flatTable; void replaceWithUnreachable() { - return replaceCurrent(Builder(*getModule()).makeUnreachable)); + replaceCurrent(Builder(*getModule()).makeUnreachable()); } }; -struct Directize : public WalkerPass<Directize> { +struct Directize : public Pass { void run(PassRunner* runner, Module* module) override { - if (!module->table.exists()) return; + if (!module->table.exists) return; if (module->table.imported()) return; - for (auto& export : module->exports) { - if (export->kind == ExternalKind::Table) return; + for (auto& ex : module->exports) { + if (ex->kind == ExternalKind::Table) return; } FlatTable flatTable(module->table); if (!flatTable.valid) return; @@ -119,7 +122,6 @@ struct Directize : public WalkerPass<Directize> { } }; - } // anonymous namespace Pass *createDirectizePass() { diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index a5cfc2116..f20cc19fd 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -75,7 +75,7 @@ void PassRegistry::registerPasses() { registerPass("code-folding", "fold code, merging duplicates", createCodeFoldingPass); registerPass("const-hoisting", "hoist repeated constants to a local", createConstHoistingPass); registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass); - registerPass("directize", "turns indirect calls into direct ones", createDirectize); + registerPass("directize", "turns indirect calls into direct ones", createDirectizePass); registerPass("dfo", "optimizes using the DataFlow SSA IR", createDataFlowOptsPass); registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass); registerPass("extract-function", "leaves just one function (useful for debugging)", createExtractFunctionPass); @@ -229,8 +229,8 @@ void PassRunner::addDefaultGlobalOptimizationPrePasses() { } void PassRunner::addDefaultGlobalOptimizationPostPasses() { - add("directize"); if (options.optimizeLevel >= 2 || options.shrinkLevel >= 1) { + add("directize"); add("dae-optimizing"); } if (options.optimizeLevel >= 2 || options.shrinkLevel >= 2) { diff --git a/test/passes/directize.wast b/test/passes/directize.wast new file mode 100644 index 000000000..8e6839457 --- /dev/null +++ b/test/passes/directize.wast @@ -0,0 +1,182 @@ +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (elem (i32.const 1) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 1) + ) + ) +) +;; at table edges +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (elem (i32.const 4) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 4) + ) + ) +) +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (elem (i32.const 0) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 0) + ) + ) +) +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (elem (i32.const 0) $foo $foo $foo $foo $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 2) + ) + ) +) +;; imported table +(module + (type $ii (func (param i32 i32))) + (import "env" "table" (table $table 5 5 funcref)) + (elem (i32.const 1) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 1) + ) + ) +) +;; exported table +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (export "tab" (table $0)) + (elem (i32.const 1) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 1) + ) + ) +) +;; non-constant table offset +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (global $g (mut i32) (i32.const 1)) + (elem (global.get $g) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 1) + ) + ) +) +;; non-constant call index +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (elem (i32.const 1) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) (param $z i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (local.get $z) + ) + ) +) +;; bad index +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (elem (i32.const 1) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 5) + ) + ) +) +;; missing index +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (elem (i32.const 1) $foo) + (func $foo (param i32) (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 2) + ) + ) +) +;; bad type +(module + (type $ii (func (param i32 i32))) + (table $0 5 5 funcref) + (elem (i32.const 1) $foo) + (func $foo (param i32) + (unreachable) + ) + (func $bar (param $x i32) (param $y i32) + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (i32.const 1) + ) + ) +) +;; no table +(module + (func $foo (param i32) + (unreachable) + ) +) + |