summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/passes/CMakeLists.txt1
-rw-r--r--src/passes/GlobalStructInference.cpp244
-rw-r--r--src/passes/pass.cpp2
-rw-r--r--src/passes/passes.h1
4 files changed, 248 insertions, 0 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index c83e95d9e..f74dd4f0b 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -38,6 +38,7 @@ set(passes_SOURCES
FuncCastEmulation.cpp
GenerateDynCalls.cpp
GlobalRefining.cpp
+ GlobalStructInference.cpp
GlobalTypeOptimization.cpp
Heap2Local.cpp
I64ToI32Lowering.cpp
diff --git a/src/passes/GlobalStructInference.cpp b/src/passes/GlobalStructInference.cpp
new file mode 100644
index 000000000..42fadf295
--- /dev/null
+++ b/src/passes/GlobalStructInference.cpp
@@ -0,0 +1,244 @@
+/*
+ * Copyright 2022 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Finds types which are only created in assignments to immutable globals. For
+// such types we can replace a struct.get with this pattern:
+//
+// (struct.get $foo i
+// (..ref..))
+// =>
+// (select
+// (value1)
+// (value2)
+// (ref.eq
+// (..ref..)
+// (global.get $global1)))
+//
+// That is a valid transformation if there are only two struct.news of $foo, it
+// is created in two immutable globals $global1 and $global2, the field is
+// immutable, the values of field |i| in them are value1 and value2
+// respectively, and $foo has no subtypes. In that situation, the reference must
+// be one of those two, so we can compare the reference to the globals and pick
+// the right value there. (We can also handle subtypes, if we look at their
+// values as well, see below.)
+//
+// The benefit of this optimization is primarily in the case of constant values
+// that we can heavily optimize, like function references (constant function
+// refs let us inline, etc.). Function references cannot be directly compared,
+// so we cannot use ConstantFieldPropagation or such with an extension to
+// multiple values, as the select pattern shown above can't be used - it needs a
+// comparison. But we can compare structs, so if the function references are in
+// vtables, and the vtables follow the above pattern, then we can optimize.
+//
+
+#include "ir/find_all.h"
+#include "ir/module-utils.h"
+#include "ir/subtypes.h"
+#include "pass.h"
+#include "wasm-builder.h"
+#include "wasm.h"
+
+namespace wasm {
+
+namespace {
+
+struct GlobalStructInference : public Pass {
+ // Maps optimizable struct types to the globals whose init is a struct.new of
+ // them. If a global is not present here, it cannot be optimized.
+ std::unordered_map<HeapType, std::vector<Name>> typeGlobals;
+
+ void run(PassRunner* runner, Module* module) override {
+ if (getTypeSystem() != TypeSystem::Nominal) {
+ Fatal() << "GlobalStructInference requires nominal typing";
+ }
+
+ // First, find all the information we need. We need to know which struct
+ // types are created in functions, because we will not be able to optimize
+ // those.
+
+ using HeapTypes = std::unordered_set<HeapType>;
+
+ ModuleUtils::ParallelFunctionAnalysis<HeapTypes> analysis(
+ *module, [&](Function* func, HeapTypes& types) {
+ if (func->imported()) {
+ return;
+ }
+
+ for (auto* structNew : FindAll<StructNew>(func->body).list) {
+ auto type = structNew->type;
+ if (type.isRef()) {
+ types.insert(type.getHeapType());
+ }
+ }
+ });
+
+ // We cannot optimize types that appear in a struct.new in a function, which
+ // we just collected and merge now.
+ HeapTypes unoptimizable;
+
+ for (auto& [func, types] : analysis.map) {
+ for (auto type : types) {
+ unoptimizable.insert(type);
+ }
+ }
+
+ // Process the globals.
+ for (auto& global : module->globals) {
+ if (global->imported()) {
+ continue;
+ }
+
+ // We cannot optimize a type that appears in a non-toplevel location in a
+ // global init.
+ for (auto* structNew : FindAll<StructNew>(global->init).list) {
+ auto type = structNew->type;
+ if (type.isRef() && structNew != global->init) {
+ unoptimizable.insert(type.getHeapType());
+ }
+ }
+
+ if (!global->init->type.isRef()) {
+ continue;
+ }
+
+ auto type = global->init->type.getHeapType();
+
+ // We cannot optimize mutable globals.
+ if (global->mutable_) {
+ unoptimizable.insert(type);
+ continue;
+ }
+
+ // Finally, if this is a struct.new then it is one we can optimize; note
+ // it.
+ if (global->init->is<StructNew>()) {
+ typeGlobals[type].push_back(global->name);
+ }
+ }
+
+ // A struct.get might also read from any of the subtypes. As a result, an
+ // unoptimizable type makes all its supertypes unoptimizable as well.
+ // TODO: this could be specific per field (and not all supers have all
+ // fields)
+ for (auto type : unoptimizable) {
+ while (1) {
+ typeGlobals.erase(type);
+ auto super = type.getSuperType();
+ if (!super) {
+ break;
+ }
+ type = *super;
+ }
+ }
+
+ // Similarly, propagate global names: if one type has [global1], then a get
+ // of any supertype might access that, so propagate to them.
+ auto typeGlobalsCopy = typeGlobals;
+ for (auto& [type, globals] : typeGlobalsCopy) {
+ auto curr = type;
+ while (1) {
+ auto super = curr.getSuperType();
+ if (!super) {
+ break;
+ }
+ curr = *super;
+ for (auto global : globals) {
+ typeGlobals[curr].push_back(global);
+ }
+ }
+ }
+
+ if (typeGlobals.empty()) {
+ // We found nothing we can optimize.
+ return;
+ }
+
+ // Optimize based on the above.
+ struct FunctionOptimizer
+ : public WalkerPass<PostWalker<FunctionOptimizer>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new FunctionOptimizer(parent); }
+
+ FunctionOptimizer(GlobalStructInference& parent) : parent(parent) {}
+
+ void visitStructGet(StructGet* curr) {
+ auto type = curr->ref->type;
+ if (type == Type::unreachable) {
+ return;
+ }
+
+ auto iter = parent.typeGlobals.find(type.getHeapType());
+ if (iter == parent.typeGlobals.end()) {
+ return;
+ }
+
+ auto& globals = iter->second;
+
+ // TODO: more sizes
+ if (globals.size() != 2) {
+ return;
+ }
+
+ // Check if the relevant fields contain constants, and are immutable.
+ auto& wasm = *getModule();
+ auto fieldIndex = curr->index;
+ auto& field = type.getHeapType().getStruct().fields[fieldIndex];
+ if (field.mutable_ == Mutable) {
+ return;
+ }
+ auto fieldType = field.type;
+ std::vector<Literal> values;
+ for (Index i = 0; i < globals.size(); i++) {
+ auto* structNew = wasm.getGlobal(globals[i])->init->cast<StructNew>();
+ if (structNew->isWithDefault()) {
+ values.push_back(Literal::makeZero(fieldType));
+ } else {
+ auto* init = structNew->operands[fieldIndex];
+ if (!Properties::isConstantExpression(init)) {
+ // Non-constant; give up entirely.
+ return;
+ }
+ values.push_back(Properties::getLiteral(init));
+ }
+ }
+
+ // Excellent, we can optimize here! Emit a select.
+ //
+ // Note that we must trap on null, so add a ref.as_non_null here.
+ Builder builder(wasm);
+ replaceCurrent(builder.makeSelect(
+ builder.makeRefEq(builder.makeRefAs(RefAsNonNull, curr->ref),
+ builder.makeGlobalGet(
+ globals[0], wasm.getGlobal(globals[0])->type)),
+ builder.makeConstantExpression(values[0]),
+ builder.makeConstantExpression(values[1])));
+ }
+
+ private:
+ GlobalStructInference& parent;
+ };
+
+ FunctionOptimizer(*this).run(runner, module);
+ }
+};
+
+} // anonymous namespace
+
+Pass* createGlobalStructInferencePass() { return new GlobalStructInference(); }
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 04947ae07..85a240dc3 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -170,6 +170,8 @@ void PassRegistry::registerPasses() {
"global-refining", "refine the types of globals", createGlobalRefiningPass);
registerPass(
"gto", "globally optimize GC types", createGlobalTypeOptimizationPass);
+ registerPass(
+ "gsi", "globally optimize struct values", createGlobalStructInferencePass);
registerPass("type-refining",
"apply more specific subtypes to type fields where possible",
createTypeRefiningPass);
diff --git a/src/passes/passes.h b/src/passes/passes.h
index d7a6f9989..2c73ed91e 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -52,6 +52,7 @@ Pass* createGenerateDynCallsPass();
Pass* createGenerateI64DynCallsPass();
Pass* createGenerateStackIRPass();
Pass* createGlobalRefiningPass();
+Pass* createGlobalStructInferencePass();
Pass* createGlobalTypeOptimizationPass();
Pass* createHeap2LocalPass();
Pass* createI64ToI32LoweringPass();