diff options
author | Alon Zakai <alonzakai@gmail.com> | 2016-05-29 16:39:17 -0700 |
---|---|---|
committer | Alon Zakai <alonzakai@gmail.com> | 2016-05-29 16:39:17 -0700 |
commit | f33f1dbbee7b3f95d8437f1ee60c9075013858b6 (patch) | |
tree | f0f7a4bc5cd7d948f4285298b3b3930f30cc0185 /src | |
parent | 1715b4a1ec845f1dd6b08f48a599f346beb0f758 (diff) | |
parent | 44aeb85b2fa2c743e2d0f7e00349f99cfcbc7639 (diff) | |
download | binaryen-f33f1dbbee7b3f95d8437f1ee60c9075013858b6.tar.gz binaryen-f33f1dbbee7b3f95d8437f1ee60c9075013858b6.tar.bz2 binaryen-f33f1dbbee7b3f95d8437f1ee60c9075013858b6.zip |
Merge pull request #550 from WebAssembly/dfe-nice
Duplicate function elimination
Diffstat (limited to 'src')
-rw-r--r-- | src/ast_utils.h | 378 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/DuplicateFunctionElimination.cpp | 179 | ||||
-rw-r--r-- | src/passes/Metrics.cpp | 4 | ||||
-rw-r--r-- | src/passes/ReorderLocals.cpp | 19 | ||||
-rw-r--r-- | src/passes/pass.cpp | 2 | ||||
-rw-r--r-- | src/support/hash.h | 39 | ||||
-rw-r--r-- | src/wasm.h | 12 |
8 files changed, 629 insertions, 5 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h index a43fc6b2f..ee5c76b69 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -17,6 +17,7 @@ #ifndef wasm_ast_utils_h #define wasm_ast_utils_h +#include "support/hash.h" #include "wasm.h" #include "wasm-traversal.h" @@ -241,6 +242,383 @@ struct ExpressionAnalyzer { // The value might be used, so it depends on if the function returns return func->result != none; } + + static bool equal(Expression* left, Expression* right) { + std::vector<Name> nameStack; + std::map<Name, std::vector<Name>> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively + Nop popNameMarker; + std::vector<Expression*> leftStack; + std::vector<Expression*> rightStack; + + auto noteNames = [&](Name left, Name right) { + if (left.is() != right.is()) return false; + if (left.is()) { + nameStack.push_back(left); + rightNames[left].push_back(right); + leftStack.push_back(&popNameMarker); + rightStack.push_back(&popNameMarker); + } + return true; + }; + auto checkNames = [&](Name left, Name right) { + auto iter = rightNames.find(left); + if (iter == rightNames.end()) return left == right; // non-internal name + return iter->second.back() == right; + }; + auto popName = [&]() { + auto left = nameStack.back(); + nameStack.pop_back(); + rightNames[left].pop_back(); + }; + + leftStack.push_back(left); + rightStack.push_back(right); + + while (leftStack.size() > 0 && rightStack.size() > 0) { + left = leftStack.back(); + leftStack.pop_back(); + right = rightStack.back(); + rightStack.pop_back(); + if (!left != !right) return false; + if (!left) continue; + if (left == &popNameMarker) { + popName(); + continue; + } + if (left->_id != right->_id) return false; + #define PUSH(clazz, what) \ + leftStack.push_back(left->cast<clazz>()->what); \ + rightStack.push_back(right->cast<clazz>()->what); + #define CHECK(clazz, what) \ + if (left->cast<clazz>()->what != right->cast<clazz>()->what) return false; + switch (left->_id) { + case Expression::Id::BlockId: { + if (!noteNames(left->cast<Block>()->name, right->cast<Block>()->name)) return false; + CHECK(Block, list.size()); + for (Index i = 0; i < left->cast<Block>()->list.size(); i++) { + PUSH(Block, list[i]); + } + break; + } + case Expression::Id::IfId: { + PUSH(If, condition); + PUSH(If, ifTrue); + PUSH(If, ifFalse); + break; + } + case Expression::Id::LoopId: { + if (!noteNames(left->cast<Loop>()->out, right->cast<Loop>()->out)) return false; + if (!noteNames(left->cast<Loop>()->in, right->cast<Loop>()->in)) return false; + PUSH(Loop, body); + break; + } + case Expression::Id::BreakId: { + if (!checkNames(left->cast<Break>()->name, right->cast<Break>()->name)) return false; + PUSH(Break, condition); + PUSH(Break, value); + break; + } + case Expression::Id::SwitchId: { + CHECK(Switch, targets.size()); + for (Index i = 0; i < left->cast<Switch>()->targets.size(); i++) { + if (!checkNames(left->cast<Switch>()->targets[i], right->cast<Switch>()->targets[i])) return false; + } + if (!checkNames(left->cast<Switch>()->default_, right->cast<Switch>()->default_)) return false; + PUSH(Switch, condition); + PUSH(Switch, value); + break; + } + case Expression::Id::CallId: { + CHECK(Call, target); + CHECK(Call, operands.size()); + for (Index i = 0; i < left->cast<Call>()->operands.size(); i++) { + PUSH(Call, operands[i]); + } + break; + } + case Expression::Id::CallImportId: { + CHECK(CallImport, target); + CHECK(CallImport, operands.size()); + for (Index i = 0; i < left->cast<CallImport>()->operands.size(); i++) { + PUSH(CallImport, operands[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + PUSH(CallIndirect, target); + CHECK(CallIndirect, fullType); + CHECK(CallIndirect, operands.size()); + for (Index i = 0; i < left->cast<CallIndirect>()->operands.size(); i++) { + PUSH(CallIndirect, operands[i]); + } + break; + } + case Expression::Id::GetLocalId: { + CHECK(GetLocal, index); + break; + } + case Expression::Id::SetLocalId: { + CHECK(SetLocal, index); + PUSH(SetLocal, value); + break; + } + case Expression::Id::LoadId: { + CHECK(Load, bytes); + CHECK(Load, signed_); + CHECK(Load, offset); + CHECK(Load, align); + PUSH(Load, ptr); + break; + } + case Expression::Id::StoreId: { + CHECK(Store, bytes); + CHECK(Store, offset); + CHECK(Store, align); + PUSH(Store, ptr); + PUSH(Store, value); + break; + } + case Expression::Id::ConstId: { + CHECK(Const, value); + break; + } + case Expression::Id::UnaryId: { + CHECK(Unary, op); + PUSH(Unary, value); + break; + } + case Expression::Id::BinaryId: { + CHECK(Binary, op); + PUSH(Binary, left); + PUSH(Binary, right); + break; + } + case Expression::Id::SelectId: { + PUSH(Select, ifTrue); + PUSH(Select, ifFalse); + PUSH(Select, condition); + break; + } + case Expression::Id::ReturnId: { + PUSH(Return, value); + break; + } + case Expression::Id::HostId: { + CHECK(Host, op); + CHECK(Host, nameOperand); + CHECK(Host, operands.size()); + for (Index i = 0; i < left->cast<Host>()->operands.size(); i++) { + PUSH(Host, operands[i]); + } + break; + } + case Expression::Id::NopId: { + break; + } + case Expression::Id::UnreachableId: { + break; + } + default: WASM_UNREACHABLE(); + } + #undef CHECK + #undef PUSH + } + if (leftStack.size() > 0 || rightStack.size() > 0) return false; + return true; + } + + // hash an expression, ignoring superficial details like specific internal names + static uint32_t hash(Expression* curr) { + uint32_t digest = 0; + + auto hash = [&digest](uint32_t hash) { + digest = rehash(digest, hash); + }; + auto hash64 = [&digest](uint64_t hash) { + digest = rehash(rehash(digest, hash >> 32), uint32_t(hash)); + }; + + std::vector<Name> nameStack; + Index internalCounter = 0; + std::map<Name, std::vector<Index>> internalNames; // for each internal name, a vector if unique ids + Nop popNameMarker; + std::vector<Expression*> stack; + + auto noteName = [&](Name curr) { + if (curr.is()) { + nameStack.push_back(curr); + internalNames[curr].push_back(internalCounter++); + stack.push_back(&popNameMarker); + } + return true; + }; + auto hashName = [&](Name curr) { + auto iter = internalNames.find(curr); + if (iter == internalNames.end()) hash64(uint64_t(curr.str)); + else hash(iter->second.back()); + }; + auto popName = [&]() { + auto curr = nameStack.back(); + nameStack.pop_back(); + internalNames[curr].pop_back(); + }; + + stack.push_back(curr); + + while (stack.size() > 0) { + curr = stack.back(); + stack.pop_back(); + if (!curr) continue; + if (curr == &popNameMarker) { + popName(); + continue; + } + hash(curr->_id); + #define PUSH(clazz, what) \ + stack.push_back(curr->cast<clazz>()->what); + #define HASH(clazz, what) \ + hash(curr->cast<clazz>()->what); + #define HASH64(clazz, what) \ + hash64(curr->cast<clazz>()->what); + #define HASH_NAME(clazz, what) \ + hash64(uint64_t(curr->cast<clazz>()->what.str)); + #define HASH_PTR(clazz, what) \ + hash64(uint64_t(curr->cast<clazz>()->what)); + switch (curr->_id) { + case Expression::Id::BlockId: { + noteName(curr->cast<Block>()->name); + HASH(Block, list.size()); + for (Index i = 0; i < curr->cast<Block>()->list.size(); i++) { + PUSH(Block, list[i]); + } + break; + } + case Expression::Id::IfId: { + PUSH(If, condition); + PUSH(If, ifTrue); + PUSH(If, ifFalse); + break; + } + case Expression::Id::LoopId: { + noteName(curr->cast<Loop>()->out); + noteName(curr->cast<Loop>()->in); + PUSH(Loop, body); + break; + } + case Expression::Id::BreakId: { + hashName(curr->cast<Break>()->name); + PUSH(Break, condition); + PUSH(Break, value); + break; + } + case Expression::Id::SwitchId: { + HASH(Switch, targets.size()); + for (Index i = 0; i < curr->cast<Switch>()->targets.size(); i++) { + hashName(curr->cast<Switch>()->targets[i]); + } + hashName(curr->cast<Switch>()->default_); + PUSH(Switch, condition); + PUSH(Switch, value); + break; + } + case Expression::Id::CallId: { + HASH_NAME(Call, target); + HASH(Call, operands.size()); + for (Index i = 0; i < curr->cast<Call>()->operands.size(); i++) { + PUSH(Call, operands[i]); + } + break; + } + case Expression::Id::CallImportId: { + HASH_NAME(CallImport, target); + HASH(CallImport, operands.size()); + for (Index i = 0; i < curr->cast<CallImport>()->operands.size(); i++) { + PUSH(CallImport, operands[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + PUSH(CallIndirect, target); + HASH_PTR(CallIndirect, fullType); + HASH(CallIndirect, operands.size()); + for (Index i = 0; i < curr->cast<CallIndirect>()->operands.size(); i++) { + PUSH(CallIndirect, operands[i]); + } + break; + } + case Expression::Id::GetLocalId: { + HASH(GetLocal, index); + break; + } + case Expression::Id::SetLocalId: { + HASH(SetLocal, index); + PUSH(SetLocal, value); + break; + } + case Expression::Id::LoadId: { + HASH(Load, bytes); + HASH(Load, signed_); + HASH(Load, offset); + HASH(Load, align); + PUSH(Load, ptr); + break; + } + case Expression::Id::StoreId: { + HASH(Store, bytes); + HASH(Store, offset); + HASH(Store, align); + PUSH(Store, ptr); + PUSH(Store, value); + break; + } + case Expression::Id::ConstId: { + HASH(Const, value.type); + HASH64(Const, value.getBits()); + break; + } + case Expression::Id::UnaryId: { + HASH(Unary, op); + PUSH(Unary, value); + break; + } + case Expression::Id::BinaryId: { + HASH(Binary, op); + PUSH(Binary, left); + PUSH(Binary, right); + break; + } + case Expression::Id::SelectId: { + PUSH(Select, ifTrue); + PUSH(Select, ifFalse); + PUSH(Select, condition); + break; + } + case Expression::Id::ReturnId: { + PUSH(Return, value); + break; + } + case Expression::Id::HostId: { + HASH(Host, op); + HASH_NAME(Host, nameOperand); + HASH(Host, operands.size()); + for (Index i = 0; i < curr->cast<Host>()->operands.size(); i++) { + PUSH(Host, operands[i]); + } + break; + } + case Expression::Id::NopId: { + break; + } + case Expression::Id::UnreachableId: { + break; + } + default: WASM_UNREACHABLE(); + } + #undef HASH + #undef PUSH + } + return digest; + } }; } // namespace wasm diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 0ec2686da..ab55dff85 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -3,6 +3,7 @@ SET(passes_SOURCES CoalesceLocals.cpp DeadCodeElimination.cpp DropReturnValues.cpp + DuplicateFunctionElimination.cpp LowerIfElse.cpp MergeBlocks.cpp Metrics.cpp diff --git a/src/passes/DuplicateFunctionElimination.cpp b/src/passes/DuplicateFunctionElimination.cpp new file mode 100644 index 000000000..593ddb7d9 --- /dev/null +++ b/src/passes/DuplicateFunctionElimination.cpp @@ -0,0 +1,179 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Removes duplicate functions. That can happen due to C++ templates, +// and also due to types being different at the source level, but +// identical when finally lowered into concrete wasm code. +// + +#include <wasm.h> +#include <pass.h> +#include <ast_utils.h> + +namespace wasm { + +struct FunctionHasher : public PostWalker<FunctionHasher, Visitor<FunctionHasher>> { + bool isFunctionParallel() { return true; } + + FunctionHasher* create() override { + auto* ret = new FunctionHasher; + ret->setOutput(output); + return ret; + } + + void setOutput(std::map<Function*, uint32_t>* output_) { + output = output_; + } + + void walk(Expression*& root) { + assert(digest == 0); + auto* func = getFunction(); + hash(func->getNumParams()); + for (auto type : func->params) hash(type); + hash(func->getNumVars()); + for (auto type : func->vars) hash(type); + hash(func->result); + hash64(func->type.is() ? uint64_t(func->type.str) : uint64_t(0)); + hash(ExpressionAnalyzer::hash(root)); + output->at(func) = digest; + } + +private: + std::map<Function*, uint32_t>* output; + uint32_t digest = 0; + + void hash(uint32_t hash) { + digest = rehash(digest, hash); + } + void hash64(uint64_t hash) { + digest = rehash(rehash(digest, hash >> 32), uint32_t(hash)); + }; +}; + +struct FunctionReplacer : public PostWalker<FunctionReplacer, Visitor<FunctionReplacer>> { + bool isFunctionParallel() { return true; } + + FunctionReplacer* create() override { + auto* ret = new FunctionReplacer; + ret->setReplacements(replacements); + return ret; + } + + void setReplacements(std::map<Name, Name>* replacements_) { + replacements = replacements_; + } + + void visitCall(Call* curr) { + auto iter = replacements->find(curr->target); + if (iter != replacements->end()) { + curr->target = iter->second; + } + } + +private: + std::map<Name, Name>* replacements; +}; + +struct DuplicateFunctionElimination : public Pass { + void run(PassRunner* runner, Module* module) override { + while (1) { + // Hash all the functions + hashes.clear(); + for (auto& func : module->functions) { + hashes[func.get()] = 0; // ensure an entry for each function - we must not modify the map shape in parallel, just the values + } + FunctionHasher hasher; + hasher.setOutput(&hashes); + hasher.startWalk(module); + // Find hash-equal groups + std::map<uint32_t, std::vector<Function*>> hashGroups; + for (auto& func : module->functions) { + hashGroups[hashes[func.get()]].push_back(func.get()); + } + // Find actually equal functions and prepare to replace them + std::map<Name, Name> replacements; + std::set<Name> duplicates; + for (auto& pair : hashGroups) { + auto& group = pair.second; + if (group.size() == 1) continue; + // pick a base for each group, and try to replace everyone else to it. TODO: multiple bases per hash group, for collisions + Function* base = group[0]; + for (auto* func : group) { + if (func != base && equal(func, base)) { + replacements[func->name] = base->name; + duplicates.insert(func->name); + } + } + } + // perform replacements + if (replacements.size() > 0) { + // remove the duplicates + auto& v = module->functions; + v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Function>& curr) { + return duplicates.count(curr->name) > 0; + }), v.end()); + module->updateFunctionsMap(); + // replace direct calls + FunctionReplacer replacer; + replacer.setReplacements(&replacements); + replacer.startWalk(module); + // replace in table + for (auto& name : module->table.names) { + auto iter = replacements.find(name); + if (iter != replacements.end()) { + name = iter->second; + } + } + // replace in start + if (module->start.is()) { + auto iter = replacements.find(module->start); + if (iter != replacements.end()) { + module->start = iter->second; + } + } + // replace in exports + for (auto& exp : module->exports) { + auto iter = replacements.find(exp->value); + if (iter != replacements.end()) { + exp->value = iter->second; + } + } + } else { + break; + } + } + } + +private: + std::map<Function*, uint32_t> hashes; + + bool equal(Function* left, Function* right) { + if (left->getNumParams() != right->getNumParams()) return false; + if (left->getNumVars() != right->getNumVars()) return false; + for (Index i = 0; i < left->getNumLocals(); i++) { + if (left->getLocalType(i) != right->getLocalType(i)) return false; + } + if (left->result != right->result) return false; + if (left->type != right->type) return false; + return ExpressionAnalyzer::equal(left->body, right->body); + } +}; + +static RegisterPass<DuplicateFunctionElimination> registerPass("duplicate-function-elimination", "removes duplicate functions"); + +} // namespace wasm + diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp index a6d51adbf..a01ead073 100644 --- a/src/passes/Metrics.cpp +++ b/src/passes/Metrics.cpp @@ -54,6 +54,10 @@ struct Metrics : public WalkerPass<PostWalker<Metrics, UnifiedExpressionVisitor< } keys.push_back("[vars]"); counts["[vars]"] = vars; + // add functions + keys.push_back("[funcs]"); + counts["[funcs]"] = module->functions.size(); + // sort sort(keys.begin(), keys.end(), [](const char* a, const char* b) -> bool { return strcmp(b, a) > 0; }); diff --git a/src/passes/ReorderLocals.cpp b/src/passes/ReorderLocals.cpp index 0e626e0a2..eea8ea962 100644 --- a/src/passes/ReorderLocals.cpp +++ b/src/passes/ReorderLocals.cpp @@ -17,7 +17,8 @@ // // Sorts locals by access frequency. // - +// Secondarily, sort by first appearance. This canonicalizes the order. +// #include <memory> @@ -29,7 +30,8 @@ namespace wasm { struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals, Visitor<ReorderLocals>>> { bool isFunctionParallel() { return true; } - std::map<Index, uint32_t> counts; + std::map<Index, Index> counts; // local => times it is used + std::map<Index, Index> firstUses; // local => index in the list of which local is first seen void visitFunction(Function *curr) { Index num = curr->getNumLocals(); @@ -44,10 +46,11 @@ struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals, Visitor<Reord if (curr->isParam(b) && curr->isParam(a)) { return a < b; } - if (this->counts[a] == this->counts[b]) { - return a < b; + if (counts[a] == counts[b]) { + if (counts[a] == 0) return a < b; + return firstUses[a] < firstUses[b]; } - return this->counts[a] > this->counts[b]; + return counts[a] > counts[b]; }); // sorting left params in front, perhaps slightly reordered. verify and fix. for (size_t i = 0; i < curr->params.size(); i++) { @@ -116,10 +119,16 @@ struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals, Visitor<Reord void visitGetLocal(GetLocal *curr) { counts[curr->index]++; + if (firstUses.count(curr->index) == 0) { + firstUses[curr->index] = firstUses.size(); + } } void visitSetLocal(SetLocal *curr) { counts[curr->index]++; + if (firstUses.count(curr->index) == 0) { + firstUses[curr->index] = firstUses.size(); + } } }; diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index b358243f7..ca9f477a3 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -58,6 +58,7 @@ std::string PassRegistry::getPassDescription(std::string name) { // PassRunner void PassRunner::addDefaultOptimizationPasses() { + add("duplicate-function-elimination"); add("dce"); add("remove-unused-brs"); add("remove-unused-names"); @@ -70,6 +71,7 @@ void PassRunner::addDefaultOptimizationPasses() { add("merge-blocks"); add("optimize-instructions"); add("vacuum"); // should not be needed, last few passes do not create garbage, but just to be safe + add("duplicate-function-elimination"); // optimizations show more functions as duplicate } void PassRunner::run() { diff --git a/src/support/hash.h b/src/support/hash.h new file mode 100644 index 000000000..e2d393d25 --- /dev/null +++ b/src/support/hash.h @@ -0,0 +1,39 @@ +/* + * Copyright 2015 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_support_hash_h +#define wasm_support_hash_h + +#include <stdint.h> + +namespace wasm { + +inline uint32_t rehash(uint32_t x, uint32_t y) { // see http://www.cse.yorku.ca/~oz/hash.html + uint32_t hash = 5381; + while (x) { + hash = ((hash << 5) + hash) ^ (x & 0xff); + x >>= 8; + } + while (y) { + hash = ((hash << 5) + hash) ^ (y & 0xff); + y >>= 8; + } + return hash; +} + +} // namespace wasm + +#endif // wasm_support_hash_h diff --git a/src/wasm.h b/src/wasm.h index e8f82f47c..f59e4368f 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -250,6 +250,14 @@ private: } } + int64_t getBits() { + switch (type) { + case WasmType::i32: case WasmType::f32: return i32; + case WasmType::i64: case WasmType::f64: return i64; + default: abort(); + } + } + bool operator==(const Literal& other) const { if (type != other.type) return false; switch (type) { @@ -262,6 +270,10 @@ private: } } + bool operator!=(const Literal& other) const { + return !(*this == other); + } + static uint32_t NaNPayload(float f) { assert(std::isnan(f) && "expected a NaN"); // SEEEEEEE EFFFFFFF FFFFFFFF FFFFFFFF |