diff options
author | Alon Zakai <azakai@google.com> | 2021-03-29 15:32:01 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-29 15:32:01 -0700 |
commit | 09cba0fa50b0492ecf7b7886180dbd6c5aa5d04d (patch) | |
tree | bc654bcc92cab64a08fe0bd9197c2f9c4a324ca4 | |
parent | 244f886cb2f1d9c5dddbf2dea5d47e6b0a434c5d (diff) | |
download | binaryen-09cba0fa50b0492ecf7b7886180dbd6c5aa5d04d.tar.gz binaryen-09cba0fa50b0492ecf7b7886180dbd6c5aa5d04d.tar.bz2 binaryen-09cba0fa50b0492ecf7b7886180dbd6c5aa5d04d.zip |
Scan module-level code in necessary places (#3744)
Several old passes like DeadArgumentElimination and DuplicateFunctionElimination
need to look at all ref.funcs, and they scanned functions for that, but that is not
enough as such an instruction might appear in a global initializer. To fix this, add a
walkModuleCode method.
walkModuleCode is useful when doing the pattern of creating a function-parallel
pass to scan functions quickly, but we also want to do the same scanning of code
at the module level. This allows doing so in a single line.
(It is also possible to just do walk() on the entire module, which will find all code,
but that is not function-parallel. Perhaps we should have a walkParallel() option
to simplify this further in a followup, and that would call walkModuleCode afterwards
etc.)
Also add some missing validation and comments in the validator about issues that
I noticed in relation to the new testcases here.
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 7 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 14 | ||||
-rw-r--r-- | src/passes/opt-utils.h | 9 | ||||
-rw-r--r-- | src/wasm-traversal.h | 19 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 14 | ||||
-rw-r--r-- | test/example/c-api-kitchen-sink.c | 6 | ||||
-rw-r--r-- | test/passes/dae_all-features.txt | 20 | ||||
-rw-r--r-- | test/passes/dae_all-features.wast | 21 | ||||
-rw-r--r-- | test/passes/duplicate-function-elimination_all-features.txt | 13 | ||||
-rw-r--r-- | test/passes/duplicate-function-elimination_all-features.wast | 18 |
10 files changed, 119 insertions, 22 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index c4de07d81..975b291b0 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -279,16 +279,15 @@ struct DAE : public Pass { for (auto& func : module->functions) { infoMap[func->name]; } - // Check the influence of the table and exports. + DAEScanner scanner(&infoMap); + scanner.walkModuleCode(module); for (auto& curr : module->exports) { if (curr->kind == ExternalKind::Function) { infoMap[curr->value].hasUnseenCalls = true; } } - ElementUtils::iterAllElementFunctionNames( - module, [&](Name name) { infoMap[name].hasUnseenCalls = true; }); // Scan all the functions. - DAEScanner(&infoMap).run(runner, module); + scanner.run(runner, module); // Combine all the info. std::unordered_map<Name, std::vector<Call*>> allCalls; std::unordered_set<Name> tailCallees; diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index b69d8f7c2..0e6f6abe4 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -344,23 +344,15 @@ struct Inlining : public Pass { infos[func->name]; } PassRunner runner(module); - FunctionInfoScanner(&infos).run(&runner, module); + FunctionInfoScanner scanner(&infos); + scanner.run(&runner, module); // fill in global uses + scanner.walkModuleCode(module); for (auto& ex : module->exports) { if (ex->kind == ExternalKind::Function) { infos[ex->value].usedGlobally = true; } } - ElementUtils::iterAllElementFunctionNames( - module, [&](Name name) { infos[name].usedGlobally = true; }); - - for (auto& global : module->globals) { - if (!global->imported()) { - for (auto* ref : FindAll<RefFunc>(global->init).list) { - infos[ref->func].usedGlobally = true; - } - } - } if (module->start.is()) { infos[module->start].usedGlobally = true; } diff --git a/src/passes/opt-utils.h b/src/passes/opt-utils.h index b333779f7..5f4ab545c 100644 --- a/src/passes/opt-utils.h +++ b/src/passes/opt-utils.h @@ -21,6 +21,7 @@ #include <unordered_set> #include <ir/element-utils.h> +#include <ir/module-utils.h> #include <pass.h> #include <wasm.h> @@ -84,10 +85,10 @@ inline void replaceFunctions(PassRunner* runner, name = iter->second; } }; - // replace direct calls - FunctionRefReplacer(maybeReplace).run(runner, &module); - // replace in table - ElementUtils::iterAllElementFunctionNames(&module, maybeReplace); + // replace direct calls in code both functions and module elements + FunctionRefReplacer replacer(maybeReplace); + replacer.run(runner, &module); + replacer.walkModuleCode(&module); // replace in start if (module.start.is()) { diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 1f307c318..388584fd6 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -259,6 +259,25 @@ struct Walker : public VisitorType { self->walkMemory(&module->memory); } + // Walks module-level code, that is, code that is not in functions. + void walkModuleCode(Module* module) { + // Dispatch statically through the SubType. + SubType* self = static_cast<SubType*>(this); + for (auto& curr : module->globals) { + if (!curr->imported()) { + self->walk(curr->init); + } + } + for (auto& curr : module->elementSegments) { + if (curr->offset) { + self->walk(curr->offset); + } + for (auto* item : curr->data) { + self->walk(item); + } + } + } + // Walk implementation. We don't use recursion as ASTs may be highly // nested. diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 425792e22..8bc12cadf 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -1998,6 +1998,13 @@ void FunctionValidator::visitRefFunc(RefFunc* curr) { shouldBeTrue(curr->type.isFunction(), curr, "ref.func must have a function reference type"); + // TODO: verify it also has a typed function references type, and the right + // one, + // curr->type.getHeapType().getSignature() + // That is blocked on having the ability to create signature types in the C + // API (for now those users create the type with funcref). This also needs to + // be fixed in LegalizeJSInterface and FuncCastEmulation and other places that + // update function types. // TODO: check for non-nullability } @@ -2843,6 +2850,9 @@ static void validateTables(Module& module, ValidationInfo& info) { auto table = module.getTableOrNull(segment->table); info.shouldBeTrue( table != nullptr, "elem", "elem segment must have a valid table name"); + info.shouldBeTrue(!!segment->offset, + "elem", + "table segment offset should have an offset"); info.shouldBeEqual(segment->offset->type, Type(Type::i32), segment->offset, @@ -2853,6 +2863,10 @@ static void validateTables(Module& module, ValidationInfo& info) { segment->offset, "table segment offset should be reasonable"); validator.validate(segment->offset); + } else { + info.shouldBeTrue(!segment->offset, + "elem", + "non-table segment offset should have no offset"); } // Avoid double checking items if (module.features.hasReferenceTypes()) { diff --git a/test/example/c-api-kitchen-sink.c b/test/example/c-api-kitchen-sink.c index 9e6366d36..aa82723b4 100644 --- a/test/example/c-api-kitchen-sink.c +++ b/test/example/c-api-kitchen-sink.c @@ -893,12 +893,12 @@ void test_core() { BinaryenModuleSetFeatures(module, features); assert(BinaryenModuleGetFeatures(module) == features); - // Verify it validates - assert(BinaryenModuleValidate(module)); - // Print it out BinaryenModulePrint(module); + // Verify it validates + assert(BinaryenModuleValidate(module)); + // Clean up the module, which owns all the objects we created above BinaryenModuleDispose(module); } diff --git a/test/passes/dae_all-features.txt b/test/passes/dae_all-features.txt index 4ca21576f..ea439bc9e 100644 --- a/test/passes/dae_all-features.txt +++ b/test/passes/dae_all-features.txt @@ -292,3 +292,23 @@ (ref.func $0) ) ) +(module + (type $none_=>_none (func)) + (type $i64 (func (param i64))) + (global $global$0 (ref $i64) (ref.func $0)) + (export "even" (func $1)) + (func $0 (param $0 i64) + (unreachable) + ) + (func $1 + (call_ref + (i64.const 0) + (global.get $global$0) + ) + ) + (func $2 + (call $0 + (i64.const 0) + ) + ) +) diff --git a/test/passes/dae_all-features.wast b/test/passes/dae_all-features.wast index 097c144cc..55e935f3b 100644 --- a/test/passes/dae_all-features.wast +++ b/test/passes/dae_all-features.wast @@ -172,3 +172,24 @@ (ref.func $0) ) ) +(module + (type $i64 (func (param i64))) + (global $global$0 (ref $i64) (ref.func $0)) + (export "even" (func $1)) + ;; the argument to this function cannot be removed due to the ref.func of it + ;; in a global + (func $0 (param $0 i64) + (unreachable) + ) + (func $1 + (call_ref + (i64.const 0) + (global.get $global$0) + ) + ) + (func $2 + (call $0 + (i64.const 0) + ) + ) +) diff --git a/test/passes/duplicate-function-elimination_all-features.txt b/test/passes/duplicate-function-elimination_all-features.txt index a7a751f76..6dd5a1004 100644 --- a/test/passes/duplicate-function-elimination_all-features.txt +++ b/test/passes/duplicate-function-elimination_all-features.txt @@ -19,3 +19,16 @@ (nop) ) ) +(module + (type $func (func (result i32))) + (global $global$0 (ref $func) (ref.func $foo)) + (export "export" (func $2)) + (func $foo (result i32) + (unreachable) + ) + (func $2 (result i32) + (call_ref + (global.get $global$0) + ) + ) +) diff --git a/test/passes/duplicate-function-elimination_all-features.wast b/test/passes/duplicate-function-elimination_all-features.wast index 116542c96..1d04e878c 100644 --- a/test/passes/duplicate-function-elimination_all-features.wast +++ b/test/passes/duplicate-function-elimination_all-features.wast @@ -21,3 +21,21 @@ (func $foo ;; happens to share a name with the memory ) ) +;; renaming after deduplication must update ref.funcs in globals +(module + (type $func (func (result i32))) + (global $global$0 (ref $func) (ref.func $bar)) + ;; These two identical functions can be merged. The ref.func in the global must + ;; be updated accordingly. + (func $foo (result i32) + (unreachable) + ) + (func $bar (result i32) + (unreachable) + ) + (func "export" (result i32) + (call_ref + (global.get $global$0) + ) + ) +) |