summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ir/type-updating.h28
-rw-r--r--src/passes/CMakeLists.txt1
-rw-r--r--src/passes/SignaturePruning.cpp200
-rw-r--r--src/passes/SignatureRefining.cpp20
-rw-r--r--src/passes/param-utils.cpp15
-rw-r--r--src/passes/pass.cpp4
-rw-r--r--src/passes/passes.h1
-rw-r--r--src/tools/wasm-reduce.cpp1
8 files changed, 248 insertions, 22 deletions
diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h
index 85134da67..d1c4af6bc 100644
--- a/src/ir/type-updating.h
+++ b/src/ir/type-updating.h
@@ -351,6 +351,34 @@ public:
// things.
Type getTempType(Type type);
+ using SignatureUpdates = std::unordered_map<HeapType, Signature>;
+
+ // Helper for the repeating pattern of just updating Signature types using a
+ // map of old heap type => new Signature.
+ static void updateSignatures(const SignatureUpdates& updates, Module& wasm) {
+ if (updates.empty()) {
+ return;
+ }
+
+ class SignatureRewriter : public GlobalTypeRewriter {
+ const SignatureUpdates& updates;
+
+ public:
+ SignatureRewriter(Module& wasm, const SignatureUpdates& updates)
+ : GlobalTypeRewriter(wasm), updates(updates) {
+ update();
+ }
+
+ void modifySignature(HeapType oldSignatureType, Signature& sig) override {
+ auto iter = updates.find(oldSignatureType);
+ if (iter != updates.end()) {
+ sig.params = getTempType(iter->second.params);
+ sig.results = getTempType(iter->second.results);
+ }
+ }
+ } rewriter(wasm, updates);
+ }
+
private:
TypeBuilder typeBuilder;
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index ac3971690..c83e95d9e 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -75,6 +75,7 @@ set(passes_SOURCES
RoundTrip.cpp
SetGlobals.cpp
StackIR.cpp
+ SignaturePruning.cpp
SignatureRefining.cpp
Strip.cpp
StripTargetFeatures.cpp
diff --git a/src/passes/SignaturePruning.cpp b/src/passes/SignaturePruning.cpp
new file mode 100644
index 000000000..d41c04d57
--- /dev/null
+++ b/src/passes/SignaturePruning.cpp
@@ -0,0 +1,200 @@
+/*
+ * Copyright 2022 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Remove params from signature/function types where possible.
+//
+// This differs from DeadArgumentElimination in that DAE will look at each
+// function by itself, and cannot handle indirectly-called functions. This pass
+// looks at each heap type at a time, and if all functions with a heap type do
+// not use a particular param, will remove the param.
+//
+
+#include "ir/find_all.h"
+#include "ir/lubs.h"
+#include "ir/module-utils.h"
+#include "ir/type-updating.h"
+#include "param-utils.h"
+#include "pass.h"
+#include "support/sorted_vector.h"
+#include "wasm-type.h"
+#include "wasm.h"
+
+namespace wasm {
+
+namespace {
+
+struct SignaturePruning : public Pass {
+ // Maps each heap type to the possible pruned heap type. We will fill this
+ // during analysis and then use it while doing an update of the types. If a
+ // type has no improvement that we can find, it will not appear in this map.
+ std::unordered_map<HeapType, Signature> newSignatures;
+
+ void run(PassRunner* runner, Module* module) override {
+ if (getTypeSystem() != TypeSystem::Nominal) {
+ Fatal() << "SignaturePruning requires nominal typing";
+ }
+
+ if (!module->tables.empty()) {
+ // When there are tables we must also take their types into account, which
+ // would require us to take call_indirect, element segments, etc. into
+ // account. For now, do nothing if there are tables.
+ // TODO
+ return;
+ }
+
+ // First, find all the information we need. Start by collecting inside each
+ // function in parallel.
+
+ struct Info {
+ std::vector<Call*> calls;
+ std::vector<CallRef*> callRefs;
+
+ std::unordered_set<Index> usedParams;
+
+ void markUnoptimizable(Function* func) {
+ // To prevent any optimization, mark all the params as if there were
+ // used.
+ for (Index i = 0; i < func->getNumParams(); i++) {
+ usedParams.insert(i);
+ }
+ }
+ };
+
+ ModuleUtils::ParallelFunctionAnalysis<Info> analysis(
+ *module, [&](Function* func, Info& info) {
+ if (func->imported()) {
+ // Imports cannot be modified.
+ info.markUnoptimizable(func);
+ return;
+ }
+
+ info.calls = std::move(FindAll<Call>(func->body).list);
+ info.callRefs = std::move(FindAll<CallRef>(func->body).list);
+ info.usedParams = ParamUtils::getUsedParams(func);
+ });
+
+ // A map of types to all the information combined over all the functions
+ // with that type.
+ std::unordered_map<HeapType, Info> allInfo;
+
+ // Map heap types to all functions with that type.
+ std::unordered_map<HeapType, std::vector<Function*>> sigFuncs;
+
+ // Combine all the information we gathered into that map.
+ for (auto& [func, info] : analysis.map) {
+ // For direct calls, add each call to the type of the function being
+ // called.
+ for (auto* call : info.calls) {
+ allInfo[module->getFunction(call->target)->type].calls.push_back(call);
+ }
+
+ // For indirect calls, add each call_ref to the type the call_ref uses.
+ for (auto* callRef : info.callRefs) {
+ auto calledType = callRef->target->type;
+ if (calledType != Type::unreachable) {
+ allInfo[calledType.getHeapType()].callRefs.push_back(callRef);
+ }
+ }
+
+ // A parameter used in this function is used in the heap type - just one
+ // function is enough to prevent the parameter from being removed.
+ auto& allUsedParams = allInfo[func->type].usedParams;
+ for (auto index : info.usedParams) {
+ allUsedParams.insert(index);
+ }
+ sigFuncs[func->type].push_back(func);
+ }
+
+ // Exported functions cannot be modified.
+ for (auto& exp : module->exports) {
+ if (exp->kind == ExternalKind::Function) {
+ auto* func = module->getFunction(exp->value);
+ allInfo[func->type].markUnoptimizable(func);
+ }
+ }
+
+ // Find parameters to prune.
+ for (auto& [type, funcs] : sigFuncs) {
+ auto sig = type.getSignature();
+ auto& info = allInfo[type];
+ auto numParams = sig.params.size();
+ if (info.usedParams.size() == numParams) {
+ // All parameters are used, give up on this one.
+ continue;
+ }
+
+ // We found possible work! Find the specific params that are unused try to
+ // prune them.
+ SortedVector unusedParams;
+ for (Index i = 0; i < numParams; i++) {
+ if (info.usedParams.count(i) == 0) {
+ unusedParams.insert(i);
+ }
+ }
+
+ auto oldParams = sig.params;
+ auto removedIndexes = ParamUtils::removeParameters(
+ funcs, unusedParams, info.calls, info.callRefs, module, runner);
+ if (removedIndexes.empty()) {
+ continue;
+ }
+
+ // Success! Update the types.
+ std::vector<Type> newParams;
+ for (Index i = 0; i < numParams; i++) {
+ if (!removedIndexes.has(i)) {
+ newParams.push_back(oldParams[i]);
+ }
+ }
+
+ // Create a new signature. When the TypeRewriter operates below it will
+ // modify the existing heap type in place to change its signature to this
+ // one (which preserves identity, that is, even if after pruning the new
+ // signature is structurally identical to another one, it will remain
+ // nominally different from those).
+ newSignatures[type] = Signature(Type(newParams), sig.results);
+
+ // removeParameters() updates the type as it goes, but in this pass we
+ // need the type to match the other locations, nominally. That is, we need
+ // all the functions of a particular type to still have the same type
+ // after this operation, and that must be the exact same type at the
+ // relevant call_refs and so forth. The TypeRewriter below will do the
+ // right thing as it rewrites everything all at once, so we do not want
+ // the type to be modified by removeParameters(), and so we undo the type
+ // it made.
+ //
+ // Note that we cannot just ask removeParameters() to not update the type,
+ // as it adds a new local there, whose index depends on the type (which
+ // contains the # of parameters, and that determine where non-parameter
+ // local indexes begin). Rather than have it update the type and then undo
+ // that, which would add more complexity in that method, undo the change
+ // here.
+ for (auto* func : funcs) {
+ func->type = type;
+ }
+ }
+
+ // Rewrite the types.
+ GlobalTypeRewriter::updateSignatures(newSignatures, *module);
+ }
+};
+
+} // anonymous namespace
+
+Pass* createSignaturePruningPass() { return new SignaturePruning(); }
+
+} // namespace wasm
diff --git a/src/passes/SignatureRefining.cpp b/src/passes/SignatureRefining.cpp
index 623a393c1..37ca091df 100644
--- a/src/passes/SignatureRefining.cpp
+++ b/src/passes/SignatureRefining.cpp
@@ -35,8 +35,6 @@
#include "wasm-type.h"
#include "wasm.h"
-using namespace std;
-
namespace wasm {
namespace {
@@ -227,23 +225,7 @@ struct SignatureRefining : public Pass {
CodeUpdater(*this, *module).run(runner, module);
// Rewrite the types.
- class TypeRewriter : public GlobalTypeRewriter {
- SignatureRefining& parent;
-
- public:
- TypeRewriter(Module& wasm, SignatureRefining& parent)
- : GlobalTypeRewriter(wasm), parent(parent) {}
-
- void modifySignature(HeapType oldSignatureType, Signature& sig) override {
- auto iter = parent.newSignatures.find(oldSignatureType);
- if (iter != parent.newSignatures.end()) {
- sig.params = getTempType(iter->second.params);
- sig.results = getTempType(iter->second.results);
- }
- }
- };
-
- TypeRewriter(*module, *this).update();
+ GlobalTypeRewriter::updateSignatures(newSignatures, *module);
if (refinedResults) {
// After return types change we need to propagate.
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp
index ded96a826..ae641fd26 100644
--- a/src/passes/param-utils.cpp
+++ b/src/passes/param-utils.cpp
@@ -62,15 +62,24 @@ bool removeParameter(const std::vector<Function*>& funcs,
// Check if none of the calls has a param with side effects that we cannot
// remove (as if we can remove them, we will simply do that when we remove the
// parameter). Note: flattening the IR beforehand can help here.
+ auto hasBadEffects = [&](ExpressionList& operands) {
+ return EffectAnalyzer(runner->options, *module, operands[index])
+ .hasUnremovableSideEffects();
+ };
bool callParamsAreValid =
std::none_of(calls.begin(), calls.end(), [&](Call* call) {
- auto* operand = call->operands[index];
- return EffectAnalyzer(runner->options, *module, operand)
- .hasUnremovableSideEffects();
+ return hasBadEffects(call->operands);
});
if (!callParamsAreValid) {
return false;
}
+ bool callRefParamsAreValid =
+ std::none_of(callRefs.begin(), callRefs.end(), [&](CallRef* call) {
+ return hasBadEffects(call->operands);
+ });
+ if (!callRefParamsAreValid) {
+ return false;
+ }
// The type must be valid for us to handle as a local (since we
// replace the parameter with a local).
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 3ac4485e3..04947ae07 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -353,6 +353,9 @@ void PassRegistry::registerPasses() {
registerPass("set-globals",
"sets specified globals to specified values",
createSetGlobalsPass);
+ registerPass("signature-pruning",
+ "remove params from function signature types where possible",
+ createSignaturePruningPass);
registerPass("signature-refining",
"apply more specific subtypes to signature types where possible",
createSignatureRefiningPass);
@@ -549,6 +552,7 @@ void PassRunner::addDefaultGlobalOptimizationPrePasses() {
if (wasm->features.hasGC() && getTypeSystem() == TypeSystem::Nominal &&
options.optimizeLevel >= 2) {
addIfNoDWARFIssues("type-refining");
+ addIfNoDWARFIssues("signature-pruning");
addIfNoDWARFIssues("signature-refining");
addIfNoDWARFIssues("global-refining");
// Global type optimization can remove fields that are not needed, which can
diff --git a/src/passes/passes.h b/src/passes/passes.h
index a703e872f..d7a6f9989 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -112,6 +112,7 @@ Pass* createRedundantSetEliminationPass();
Pass* createRoundTripPass();
Pass* createSafeHeapPass();
Pass* createSetGlobalsPass();
+Pass* createSignaturePruningPass();
Pass* createSignatureRefiningPass();
Pass* createSimplifyLocalsPass();
Pass* createSimplifyGlobalsPass();
diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp
index 2ca41339e..3b29571ba 100644
--- a/src/tools/wasm-reduce.cpp
+++ b/src/tools/wasm-reduce.cpp
@@ -279,6 +279,7 @@ struct Reducer
"--remove-unused-nonfunction-module-elements",
"--reorder-functions",
"--reorder-locals",
+ // TODO: signature* passes
"--simplify-globals",
"--simplify-locals --vacuum",
"--strip",