summaryrefslogtreecommitdiff
path: root/src/passes/param-utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/param-utils.cpp')
-rw-r--r--src/passes/param-utils.cpp177
1 files changed, 177 insertions, 0 deletions
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp
new file mode 100644
index 000000000..019df6d7f
--- /dev/null
+++ b/src/passes/param-utils.cpp
@@ -0,0 +1,177 @@
+/*
+ * Copyright 2022 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/function-utils.h"
+#include "ir/local-graph.h"
+#include "ir/type-updating.h"
+#include "support/sorted_vector.h"
+#include "wasm.h"
+
+namespace wasm::ParamUtils {
+
+std::unordered_set<Index> getUsedParams(Function* func) {
+ LocalGraph localGraph(func);
+
+ std::unordered_set<Index> usedParams;
+
+ for (auto& [get, sets] : localGraph.getSetses) {
+ if (!func->isParam(get->index)) {
+ continue;
+ }
+
+ for (auto* set : sets) {
+ // A nullptr value indicates there is no LocalSet* that sets the value,
+ // so it must be the parameter value.
+ if (!set) {
+ usedParams.insert(get->index);
+ }
+ }
+ }
+
+ return usedParams;
+}
+
+bool removeParameter(const std::vector<Function*> funcs,
+ Index index,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner) {
+ assert(funcs.size() > 0);
+ auto* first = funcs[0];
+#ifndef NDEBUG
+ for (auto* func : funcs) {
+ assert(func->type == first->type);
+ }
+#endif
+
+ // Check if none of the calls has a param with side effects that we cannot
+ // remove (as if we can remove them, we will simply do that when we remove the
+ // parameter). Note: flattening the IR beforehand can help here.
+ bool callParamsAreValid =
+ std::none_of(calls.begin(), calls.end(), [&](Call* call) {
+ auto* operand = call->operands[index];
+ return EffectAnalyzer(runner->options, *module, operand)
+ .hasUnremovableSideEffects();
+ });
+ if (!callParamsAreValid) {
+ return false;
+ }
+
+ // The type must be valid for us to handle as a local (since we
+ // replace the parameter with a local).
+ // TODO: if there are no references at all, we can avoid creating a
+ // local
+ bool typeIsValid = TypeUpdating::canHandleAsLocal(first->getLocalType(index));
+ if (!typeIsValid) {
+ return false;
+ }
+
+ // We can do it!
+
+ // Remove the parameter from the function. We must add a new local
+ // for uses of the parameter, but cannot make it use the same index
+ // (in general).
+ auto paramsType = first->getParams();
+ std::vector<Type> params(paramsType.begin(), paramsType.end());
+ auto type = params[index];
+ params.erase(params.begin() + index);
+ // TODO: parallelize some of these loops?
+ for (auto* func : funcs) {
+ func->setParams(Type(params));
+
+ // It's cumbersome to adjust local names - TODO don't clear them?
+ Builder::clearLocalNames(func);
+ }
+ std::vector<Index> newIndexes;
+ for (auto* func : funcs) {
+ newIndexes.push_back(Builder::addVar(func, type));
+ }
+ // Update local operations.
+ struct LocalUpdater : public PostWalker<LocalUpdater> {
+ Index removedIndex;
+ Index newIndex;
+ LocalUpdater(Function* func, Index removedIndex, Index newIndex)
+ : removedIndex(removedIndex), newIndex(newIndex) {
+ walk(func->body);
+ }
+ void visitLocalGet(LocalGet* curr) { updateIndex(curr->index); }
+ void visitLocalSet(LocalSet* curr) { updateIndex(curr->index); }
+ void updateIndex(Index& index) {
+ if (index == removedIndex) {
+ index = newIndex;
+ } else if (index > removedIndex) {
+ index--;
+ }
+ }
+ };
+ for (Index i = 0; i < funcs.size(); i++) {
+ auto* func = funcs[i];
+ if (!func->imported()) {
+ LocalUpdater(funcs[i], index, newIndexes[i]);
+ TypeUpdating::handleNonDefaultableLocals(func, *module);
+ }
+ }
+
+ // Remove the arguments from the calls.
+ for (auto* call : calls) {
+ call->operands.erase(call->operands.begin() + index);
+ }
+ for (auto* call : callRefs) {
+ call->operands.erase(call->operands.begin() + index);
+ }
+
+ return true;
+}
+
+SortedVector removeParameters(const std::vector<Function*> funcs,
+ SortedVector indexes,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner) {
+ if (indexes.empty()) {
+ return {};
+ }
+
+ assert(funcs.size() > 0);
+ auto* first = funcs[0];
+#ifndef NDEBUG
+ for (auto* func : funcs) {
+ assert(func->type == first->type);
+ }
+#endif
+
+ // Iterate downwards, as we may remove more than one, and going forwards would
+ // alter the indexes after us.
+ Index i = first->getNumParams() - 1;
+ SortedVector removed;
+ while (1) {
+ if (indexes.has(i)) {
+ if (removeParameter(funcs, i, calls, callRefs, module, runner)) {
+ // Success!
+ removed.insert(i);
+ }
+ }
+ if (i == 0) {
+ break;
+ }
+ i--;
+ }
+ return removed;
+}
+
+} // namespace wasm::ParamUtils