summaryrefslogtreecommitdiff
path: root/src/passes/SignaturePruning.cpp
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2022-03-25 16:33:53 -0700
committerGitHub <noreply@github.com>2022-03-25 16:33:53 -0700
commit11932cc31e88d3d368714fcca43df979f7694bd1 (patch)
tree889faa7d6c6e2e261970ea776e3183202e9d6cba /src/passes/SignaturePruning.cpp
parent3a1953a1f417eb2f588eeb35bf26a3df6ea8f8e1 (diff)
downloadbinaryen-11932cc31e88d3d368714fcca43df979f7694bd1.tar.gz
binaryen-11932cc31e88d3d368714fcca43df979f7694bd1.tar.bz2
binaryen-11932cc31e88d3d368714fcca43df979f7694bd1.zip
[Wasm GC] Signature Pruning (#4545)
This adds a new signature-pruning pass that prunes parameters from signature types where those parameters are never used in any function that has that type. This is similar to DeadArgumentElimination but works on a set of functions, and it can handle indirect calls. Also move a little code from SignatureRefining into a shared place to avoid duplication of logic to update signature types. This pattern happens in j2wasm code, for example if all method functions for some virtual method just return a constant and do not use the this pointer.
Diffstat (limited to 'src/passes/SignaturePruning.cpp')
-rw-r--r--src/passes/SignaturePruning.cpp200
1 files changed, 200 insertions, 0 deletions
diff --git a/src/passes/SignaturePruning.cpp b/src/passes/SignaturePruning.cpp
new file mode 100644
index 000000000..d41c04d57
--- /dev/null
+++ b/src/passes/SignaturePruning.cpp
@@ -0,0 +1,200 @@
+/*
+ * Copyright 2022 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Remove params from signature/function types where possible.
+//
+// This differs from DeadArgumentElimination in that DAE will look at each
+// function by itself, and cannot handle indirectly-called functions. This pass
+// looks at each heap type at a time, and if all functions with a heap type do
+// not use a particular param, will remove the param.
+//
+
+#include "ir/find_all.h"
+#include "ir/lubs.h"
+#include "ir/module-utils.h"
+#include "ir/type-updating.h"
+#include "param-utils.h"
+#include "pass.h"
+#include "support/sorted_vector.h"
+#include "wasm-type.h"
+#include "wasm.h"
+
+namespace wasm {
+
+namespace {
+
+struct SignaturePruning : public Pass {
+ // Maps each heap type to the possible pruned heap type. We will fill this
+ // during analysis and then use it while doing an update of the types. If a
+ // type has no improvement that we can find, it will not appear in this map.
+ std::unordered_map<HeapType, Signature> newSignatures;
+
+ void run(PassRunner* runner, Module* module) override {
+ if (getTypeSystem() != TypeSystem::Nominal) {
+ Fatal() << "SignaturePruning requires nominal typing";
+ }
+
+ if (!module->tables.empty()) {
+ // When there are tables we must also take their types into account, which
+ // would require us to take call_indirect, element segments, etc. into
+ // account. For now, do nothing if there are tables.
+ // TODO
+ return;
+ }
+
+ // First, find all the information we need. Start by collecting inside each
+ // function in parallel.
+
+ struct Info {
+ std::vector<Call*> calls;
+ std::vector<CallRef*> callRefs;
+
+ std::unordered_set<Index> usedParams;
+
+ void markUnoptimizable(Function* func) {
+ // To prevent any optimization, mark all the params as if there were
+ // used.
+ for (Index i = 0; i < func->getNumParams(); i++) {
+ usedParams.insert(i);
+ }
+ }
+ };
+
+ ModuleUtils::ParallelFunctionAnalysis<Info> analysis(
+ *module, [&](Function* func, Info& info) {
+ if (func->imported()) {
+ // Imports cannot be modified.
+ info.markUnoptimizable(func);
+ return;
+ }
+
+ info.calls = std::move(FindAll<Call>(func->body).list);
+ info.callRefs = std::move(FindAll<CallRef>(func->body).list);
+ info.usedParams = ParamUtils::getUsedParams(func);
+ });
+
+ // A map of types to all the information combined over all the functions
+ // with that type.
+ std::unordered_map<HeapType, Info> allInfo;
+
+ // Map heap types to all functions with that type.
+ std::unordered_map<HeapType, std::vector<Function*>> sigFuncs;
+
+ // Combine all the information we gathered into that map.
+ for (auto& [func, info] : analysis.map) {
+ // For direct calls, add each call to the type of the function being
+ // called.
+ for (auto* call : info.calls) {
+ allInfo[module->getFunction(call->target)->type].calls.push_back(call);
+ }
+
+ // For indirect calls, add each call_ref to the type the call_ref uses.
+ for (auto* callRef : info.callRefs) {
+ auto calledType = callRef->target->type;
+ if (calledType != Type::unreachable) {
+ allInfo[calledType.getHeapType()].callRefs.push_back(callRef);
+ }
+ }
+
+ // A parameter used in this function is used in the heap type - just one
+ // function is enough to prevent the parameter from being removed.
+ auto& allUsedParams = allInfo[func->type].usedParams;
+ for (auto index : info.usedParams) {
+ allUsedParams.insert(index);
+ }
+ sigFuncs[func->type].push_back(func);
+ }
+
+ // Exported functions cannot be modified.
+ for (auto& exp : module->exports) {
+ if (exp->kind == ExternalKind::Function) {
+ auto* func = module->getFunction(exp->value);
+ allInfo[func->type].markUnoptimizable(func);
+ }
+ }
+
+ // Find parameters to prune.
+ for (auto& [type, funcs] : sigFuncs) {
+ auto sig = type.getSignature();
+ auto& info = allInfo[type];
+ auto numParams = sig.params.size();
+ if (info.usedParams.size() == numParams) {
+ // All parameters are used, give up on this one.
+ continue;
+ }
+
+ // We found possible work! Find the specific params that are unused try to
+ // prune them.
+ SortedVector unusedParams;
+ for (Index i = 0; i < numParams; i++) {
+ if (info.usedParams.count(i) == 0) {
+ unusedParams.insert(i);
+ }
+ }
+
+ auto oldParams = sig.params;
+ auto removedIndexes = ParamUtils::removeParameters(
+ funcs, unusedParams, info.calls, info.callRefs, module, runner);
+ if (removedIndexes.empty()) {
+ continue;
+ }
+
+ // Success! Update the types.
+ std::vector<Type> newParams;
+ for (Index i = 0; i < numParams; i++) {
+ if (!removedIndexes.has(i)) {
+ newParams.push_back(oldParams[i]);
+ }
+ }
+
+ // Create a new signature. When the TypeRewriter operates below it will
+ // modify the existing heap type in place to change its signature to this
+ // one (which preserves identity, that is, even if after pruning the new
+ // signature is structurally identical to another one, it will remain
+ // nominally different from those).
+ newSignatures[type] = Signature(Type(newParams), sig.results);
+
+ // removeParameters() updates the type as it goes, but in this pass we
+ // need the type to match the other locations, nominally. That is, we need
+ // all the functions of a particular type to still have the same type
+ // after this operation, and that must be the exact same type at the
+ // relevant call_refs and so forth. The TypeRewriter below will do the
+ // right thing as it rewrites everything all at once, so we do not want
+ // the type to be modified by removeParameters(), and so we undo the type
+ // it made.
+ //
+ // Note that we cannot just ask removeParameters() to not update the type,
+ // as it adds a new local there, whose index depends on the type (which
+ // contains the # of parameters, and that determine where non-parameter
+ // local indexes begin). Rather than have it update the type and then undo
+ // that, which would add more complexity in that method, undo the change
+ // here.
+ for (auto* func : funcs) {
+ func->type = type;
+ }
+ }
+
+ // Rewrite the types.
+ GlobalTypeRewriter::updateSignatures(newSignatures, *module);
+ }
+};
+
+} // anonymous namespace
+
+Pass* createSignaturePruningPass() { return new SignaturePruning(); }
+
+} // namespace wasm