summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_utils.h28
-rw-r--r--src/passes/CMakeLists.txt2
-rw-r--r--src/passes/DuplicateFunctionElimination.cpp2
-rw-r--r--src/passes/LegalizeJSInterface.cpp31
-rw-r--r--src/passes/RemoveUnusedFunctions.cpp65
-rw-r--r--src/passes/RemoveUnusedModuleElements.cpp155
-rw-r--r--src/passes/pass.cpp6
-rw-r--r--src/passes/passes.h2
-rw-r--r--src/wasm.h22
9 files changed, 198 insertions, 115 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h
index f10fb40eb..2565ee24a 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -68,34 +68,6 @@ struct BreakSeeker : public PostWalker<BreakSeeker, Visitor<BreakSeeker>> {
}
};
-// Finds all functions that are reachable via direct calls.
-
-struct DirectCallGraphAnalyzer : public PostWalker<DirectCallGraphAnalyzer, Visitor<DirectCallGraphAnalyzer>> {
- Module *module;
- std::vector<Function*> queue;
- std::unordered_set<Function*> reachable;
-
- DirectCallGraphAnalyzer(Module* module, const std::vector<Function*>& root) : module(module) {
- for (auto* curr : root) {
- queue.push_back(curr);
- }
- while (queue.size()) {
- auto* curr = queue.back();
- queue.pop_back();
- if (reachable.count(curr) == 0) {
- reachable.insert(curr);
- walk(curr->body);
- }
- }
- }
- void visitCall(Call *curr) {
- auto* target = module->getFunction(curr->target);
- if (reachable.count(target) == 0) {
- queue.push_back(target);
- }
- }
-};
-
// Look for side effects, including control flow
// TODO: optimize
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index e47a7d892..9db1c66ae 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -22,7 +22,7 @@ SET(passes_SOURCES
RemoveMemory.cpp
RemoveUnusedBrs.cpp
RemoveUnusedNames.cpp
- RemoveUnusedFunctions.cpp
+ RemoveUnusedModuleElements.cpp
ReorderLocals.cpp
ReorderFunctions.cpp
SimplifyLocals.cpp
diff --git a/src/passes/DuplicateFunctionElimination.cpp b/src/passes/DuplicateFunctionElimination.cpp
index cfe2d8565..8e8342729 100644
--- a/src/passes/DuplicateFunctionElimination.cpp
+++ b/src/passes/DuplicateFunctionElimination.cpp
@@ -127,7 +127,7 @@ struct DuplicateFunctionElimination : public Pass {
v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Function>& curr) {
return duplicates.count(curr->name) > 0;
}), v.end());
- module->updateFunctionsMap();
+ module->updateMaps();
// replace direct calls
PassRunner replacerRunner(module);
replacerRunner.add<FunctionReplacer>(&replacements);
diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp
index 3819fcf72..33bba2112 100644
--- a/src/passes/LegalizeJSInterface.cpp
+++ b/src/passes/LegalizeJSInterface.cpp
@@ -126,14 +126,11 @@ private:
auto index = builder.addVar(legal, Name(), i64);
auto* block = builder.makeBlock();
block->list.push_back(builder.makeSetLocal(index, call));
- if (module->checkGlobal(TEMP_RET_0)) {
- block->list.push_back(builder.makeSetGlobal(
- TEMP_RET_0,
- I64Utilities::getI64High(builder, index)
- ));
- } else {
- block->list.push_back(builder.makeUnreachable()); // no way to emit the high bits :(
- }
+ ensureTempRet0(module);
+ block->list.push_back(builder.makeSetGlobal(
+ TEMP_RET_0,
+ I64Utilities::getI64High(builder, index)
+ ));
block->list.push_back(I64Utilities::getI64Low(builder, index));
block->finalize();
legal->body = block;
@@ -183,11 +180,8 @@ private:
if (im->functionType->result == i64) {
call->type = i32;
Expression* get;
- if (module->checkGlobal(TEMP_RET_0)) {
- get = builder.makeGetGlobal(TEMP_RET_0, i32);
- } else {
- get = builder.makeUnreachable(); // no way to emit the high bits :(
- }
+ ensureTempRet0(module);
+ get = builder.makeGetGlobal(TEMP_RET_0, i32);
func->body = I64Utilities::recreateI64(builder, call, get);
type->result = i32;
} else {
@@ -201,6 +195,17 @@ private:
module->addFunctionType(type);
return legal;
}
+
+ void ensureTempRet0(Module* module) {
+ if (!module->checkGlobal(TEMP_RET_0)) {
+ Global* global = new Global;
+ global->name = TEMP_RET_0;
+ global->type = i32;
+ global->init = module->allocator.alloc<Const>()->set(Literal(int32_t(0)));
+ global->mutable_ = true;
+ module->addGlobal(global);
+ }
+ }
};
Pass *createLegalizeJSInterfacePass() {
diff --git a/src/passes/RemoveUnusedFunctions.cpp b/src/passes/RemoveUnusedFunctions.cpp
deleted file mode 100644
index ec9e271b7..000000000
--- a/src/passes/RemoveUnusedFunctions.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Copyright 2016 WebAssembly Community Group participants
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-//
-// Removes functions that are never used.
-//
-
-
-#include <memory>
-
-#include "wasm.h"
-#include "pass.h"
-#include "ast_utils.h"
-
-namespace wasm {
-
-struct RemoveUnusedFunctions : public Pass {
- void run(PassRunner* runner, Module* module) override {
- std::vector<Function*> root;
- // Module start is a root.
- if (module->start.is()) {
- root.push_back(module->getFunction(module->start));
- }
- // Exports are roots.
- for (auto& curr : module->exports) {
- if (curr->kind == ExternalKind::Function) {
- root.push_back(module->getFunction(curr->value));
- }
- }
- // For now, all functions that can be called indirectly are marked as roots.
- for (auto& segment : module->table.segments) {
- for (auto& curr : segment.data) {
- root.push_back(module->getFunction(curr));
- }
- }
- // Compute function reachability starting from the root set.
- DirectCallGraphAnalyzer analyzer(module, root);
- // Remove unreachable functions.
- auto& v = module->functions;
- v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Function>& curr) {
- return analyzer.reachable.count(curr.get()) == 0;
- }), v.end());
- assert(module->functions.size() == analyzer.reachable.size());
- module->updateFunctionsMap();
- }
-};
-
-Pass *createRemoveUnusedFunctionsPass() {
- return new RemoveUnusedFunctions();
-}
-
-} // namespace wasm
diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp
new file mode 100644
index 000000000..cf9741961
--- /dev/null
+++ b/src/passes/RemoveUnusedModuleElements.cpp
@@ -0,0 +1,155 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Removes module elements that are are never used: functions and globals,
+// which may be imported or not.
+//
+
+
+#include <memory>
+
+#include "wasm.h"
+#include "pass.h"
+#include "ast_utils.h"
+
+namespace wasm {
+
+enum class ModuleElementKind {
+ Function,
+ Global
+};
+
+typedef std::pair<ModuleElementKind, Name> ModuleElement;
+
+// Finds reachabilities
+
+struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer, Visitor<ReachabilityAnalyzer>> {
+ Module* module;
+ std::vector<ModuleElement> queue;
+ std::set<ModuleElement> reachable;
+
+ ReachabilityAnalyzer(Module* module, const std::vector<ModuleElement>& roots) : module(module) {
+ queue = roots;
+ // Globals used in memory/table init expressions are also roots
+ for (auto& segment : module->memory.segments) {
+ walk(segment.offset);
+ }
+ for (auto& segment : module->table.segments) {
+ walk(segment.offset);
+ }
+ // main loop
+ while (queue.size()) {
+ auto& curr = queue.back();
+ queue.pop_back();
+ if (reachable.count(curr) == 0) {
+ reachable.insert(curr);
+ if (curr.first == ModuleElementKind::Function) {
+ // if not an import, walk it
+ auto* func = module->checkFunction(curr.second);
+ if (func) {
+ walk(func->body);
+ }
+ } else {
+ // if not imported, it has an init expression we need to walk
+ auto* glob = module->checkGlobal(curr.second);
+ if (glob) {
+ walk(glob->init);
+ }
+ }
+ }
+ }
+ }
+
+ void visitCall(Call* curr) {
+ if (reachable.count(ModuleElement(ModuleElementKind::Function, curr->target)) == 0) {
+ queue.emplace_back(ModuleElementKind::Function, curr->target);
+ }
+ }
+ void visitCallImport(CallImport* curr) {
+ if (reachable.count(ModuleElement(ModuleElementKind::Function, curr->target)) == 0) {
+ queue.emplace_back(ModuleElementKind::Function, curr->target);
+ }
+ }
+
+ void visitGetGlobal(GetGlobal* curr) {
+ if (reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0) {
+ queue.emplace_back(ModuleElementKind::Global, curr->name);
+ }
+ }
+ void visitSetGlobal(SetGlobal* curr) {
+ if (reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0) {
+ queue.emplace_back(ModuleElementKind::Global, curr->name);
+ }
+ }
+};
+
+struct RemoveUnusedModuleElements : public Pass {
+ void run(PassRunner* runner, Module* module) override {
+ std::vector<ModuleElement> roots;
+ // Module start is a root.
+ if (module->start.is()) {
+ roots.emplace_back(ModuleElementKind::Function, module->start);
+ }
+ // Exports are roots.
+ for (auto& curr : module->exports) {
+ if (curr->kind == ExternalKind::Function) {
+ roots.emplace_back(ModuleElementKind::Function, curr->value);
+ } else if (curr->kind == ExternalKind::Global) {
+ roots.emplace_back(ModuleElementKind::Global, curr->value);
+ }
+ }
+ // For now, all functions that can be called indirectly are marked as roots.
+ for (auto& segment : module->table.segments) {
+ for (auto& curr : segment.data) {
+ roots.emplace_back(ModuleElementKind::Function, curr);
+ }
+ }
+ // Compute reachability starting from the root set.
+ ReachabilityAnalyzer analyzer(module, roots);
+ // Remove unreachable elements.
+ {
+ auto& v = module->functions;
+ v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Function>& curr) {
+ return analyzer.reachable.count(ModuleElement(ModuleElementKind::Function, curr->name)) == 0;
+ }), v.end());
+ }
+ {
+ auto& v = module->globals;
+ v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Global>& curr) {
+ return analyzer.reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0;
+ }), v.end());
+ }
+ {
+ auto& v = module->imports;
+ v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Import>& curr) {
+ if (curr->kind == ExternalKind::Function) {
+ return analyzer.reachable.count(ModuleElement(ModuleElementKind::Function, curr->name)) == 0;
+ } else if (curr->kind == ExternalKind::Global) {
+ return analyzer.reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0;
+ }
+ return false;
+ }), v.end());
+ }
+ module->updateMaps();
+ }
+};
+
+Pass* createRemoveUnusedModuleElementsPass() {
+ return new RemoveUnusedModuleElements();
+}
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 20e002f4b..32b596eef 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -86,7 +86,7 @@ void PassRegistry::registerPasses() {
registerPass("remove-imports", "removes imports and replaces them with nops", createRemoveImportsPass);
registerPass("remove-memory", "removes memory segments", createRemoveMemoryPass);
registerPass("remove-unused-brs", "removes breaks from locations that are not needed", createRemoveUnusedBrsPass);
- registerPass("remove-unused-functions", "removes unused functions", createRemoveUnusedFunctionsPass);
+ registerPass("remove-unused-module-elements", "removes unused module elements", createRemoveUnusedModuleElementsPass);
registerPass("remove-unused-names", "removes names from locations that are never branched to", createRemoveUnusedNamesPass);
registerPass("reorder-functions", "sorts functions by access frequency", createReorderFunctionsPass);
registerPass("reorder-locals", "sorts locals by access frequency", createReorderLocalsPass);
@@ -103,7 +103,7 @@ void PassRunner::addDefaultOptimizationPasses() {
add("duplicate-function-elimination");
addDefaultFunctionOptimizationPasses();
add("duplicate-function-elimination"); // optimizations show more functions as duplicate
- add("remove-unused-functions");
+ add("remove-unused-module-elements");
add("memory-packing");
}
@@ -133,7 +133,7 @@ void PassRunner::addDefaultFunctionOptimizationPasses() {
void PassRunner::addDefaultGlobalOptimizationPasses() {
add("duplicate-function-elimination");
- add("remove-unused-functions");
+ add("remove-unused-module-elements");
add("memory-packing");
}
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 98f99654e..cbfc48327 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -46,7 +46,7 @@ Pass *createRelooperJumpThreadingPass();
Pass *createRemoveImportsPass();
Pass *createRemoveMemoryPass();
Pass *createRemoveUnusedBrsPass();
-Pass *createRemoveUnusedFunctionsPass();
+Pass *createRemoveUnusedModuleElementsPass();
Pass *createRemoveUnusedNamesPass();
Pass *createReorderFunctionsPass();
Pass *createReorderLocalsPass();
diff --git a/src/wasm.h b/src/wasm.h
index c1ef75ea9..75d6a174c 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1609,10 +1609,26 @@ public:
}
// TODO: remove* for other elements
- void updateFunctionsMap() {
+ void updateMaps() {
functionsMap.clear();
- for (auto& func : functions) {
- functionsMap[func->name] = func.get();
+ for (auto& curr : functions) {
+ functionsMap[curr->name] = curr.get();
+ }
+ functionTypesMap.clear();
+ for (auto& curr : functionTypes) {
+ functionTypesMap[curr->name] = curr.get();
+ }
+ importsMap.clear();
+ for (auto& curr : imports) {
+ importsMap[curr->name] = curr.get();
+ }
+ exportsMap.clear();
+ for (auto& curr : exports) {
+ exportsMap[curr->name] = curr.get();
+ }
+ globalsMap.clear();
+ for (auto& curr : globals) {
+ globalsMap[curr->name] = curr.get();
}
}
};