diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/GlobalStructInference.cpp | 244 | ||||
-rw-r--r-- | src/passes/pass.cpp | 2 | ||||
-rw-r--r-- | src/passes/passes.h | 1 |
4 files changed, 248 insertions, 0 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index c83e95d9e..f74dd4f0b 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -38,6 +38,7 @@ set(passes_SOURCES FuncCastEmulation.cpp GenerateDynCalls.cpp GlobalRefining.cpp + GlobalStructInference.cpp GlobalTypeOptimization.cpp Heap2Local.cpp I64ToI32Lowering.cpp diff --git a/src/passes/GlobalStructInference.cpp b/src/passes/GlobalStructInference.cpp new file mode 100644 index 000000000..42fadf295 --- /dev/null +++ b/src/passes/GlobalStructInference.cpp @@ -0,0 +1,244 @@ +/* + * Copyright 2022 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Finds types which are only created in assignments to immutable globals. For +// such types we can replace a struct.get with this pattern: +// +// (struct.get $foo i +// (..ref..)) +// => +// (select +// (value1) +// (value2) +// (ref.eq +// (..ref..) +// (global.get $global1))) +// +// That is a valid transformation if there are only two struct.news of $foo, it +// is created in two immutable globals $global1 and $global2, the field is +// immutable, the values of field |i| in them are value1 and value2 +// respectively, and $foo has no subtypes. In that situation, the reference must +// be one of those two, so we can compare the reference to the globals and pick +// the right value there. (We can also handle subtypes, if we look at their +// values as well, see below.) +// +// The benefit of this optimization is primarily in the case of constant values +// that we can heavily optimize, like function references (constant function +// refs let us inline, etc.). Function references cannot be directly compared, +// so we cannot use ConstantFieldPropagation or such with an extension to +// multiple values, as the select pattern shown above can't be used - it needs a +// comparison. But we can compare structs, so if the function references are in +// vtables, and the vtables follow the above pattern, then we can optimize. +// + +#include "ir/find_all.h" +#include "ir/module-utils.h" +#include "ir/subtypes.h" +#include "pass.h" +#include "wasm-builder.h" +#include "wasm.h" + +namespace wasm { + +namespace { + +struct GlobalStructInference : public Pass { + // Maps optimizable struct types to the globals whose init is a struct.new of + // them. If a global is not present here, it cannot be optimized. + std::unordered_map<HeapType, std::vector<Name>> typeGlobals; + + void run(PassRunner* runner, Module* module) override { + if (getTypeSystem() != TypeSystem::Nominal) { + Fatal() << "GlobalStructInference requires nominal typing"; + } + + // First, find all the information we need. We need to know which struct + // types are created in functions, because we will not be able to optimize + // those. + + using HeapTypes = std::unordered_set<HeapType>; + + ModuleUtils::ParallelFunctionAnalysis<HeapTypes> analysis( + *module, [&](Function* func, HeapTypes& types) { + if (func->imported()) { + return; + } + + for (auto* structNew : FindAll<StructNew>(func->body).list) { + auto type = structNew->type; + if (type.isRef()) { + types.insert(type.getHeapType()); + } + } + }); + + // We cannot optimize types that appear in a struct.new in a function, which + // we just collected and merge now. + HeapTypes unoptimizable; + + for (auto& [func, types] : analysis.map) { + for (auto type : types) { + unoptimizable.insert(type); + } + } + + // Process the globals. + for (auto& global : module->globals) { + if (global->imported()) { + continue; + } + + // We cannot optimize a type that appears in a non-toplevel location in a + // global init. + for (auto* structNew : FindAll<StructNew>(global->init).list) { + auto type = structNew->type; + if (type.isRef() && structNew != global->init) { + unoptimizable.insert(type.getHeapType()); + } + } + + if (!global->init->type.isRef()) { + continue; + } + + auto type = global->init->type.getHeapType(); + + // We cannot optimize mutable globals. + if (global->mutable_) { + unoptimizable.insert(type); + continue; + } + + // Finally, if this is a struct.new then it is one we can optimize; note + // it. + if (global->init->is<StructNew>()) { + typeGlobals[type].push_back(global->name); + } + } + + // A struct.get might also read from any of the subtypes. As a result, an + // unoptimizable type makes all its supertypes unoptimizable as well. + // TODO: this could be specific per field (and not all supers have all + // fields) + for (auto type : unoptimizable) { + while (1) { + typeGlobals.erase(type); + auto super = type.getSuperType(); + if (!super) { + break; + } + type = *super; + } + } + + // Similarly, propagate global names: if one type has [global1], then a get + // of any supertype might access that, so propagate to them. + auto typeGlobalsCopy = typeGlobals; + for (auto& [type, globals] : typeGlobalsCopy) { + auto curr = type; + while (1) { + auto super = curr.getSuperType(); + if (!super) { + break; + } + curr = *super; + for (auto global : globals) { + typeGlobals[curr].push_back(global); + } + } + } + + if (typeGlobals.empty()) { + // We found nothing we can optimize. + return; + } + + // Optimize based on the above. + struct FunctionOptimizer + : public WalkerPass<PostWalker<FunctionOptimizer>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new FunctionOptimizer(parent); } + + FunctionOptimizer(GlobalStructInference& parent) : parent(parent) {} + + void visitStructGet(StructGet* curr) { + auto type = curr->ref->type; + if (type == Type::unreachable) { + return; + } + + auto iter = parent.typeGlobals.find(type.getHeapType()); + if (iter == parent.typeGlobals.end()) { + return; + } + + auto& globals = iter->second; + + // TODO: more sizes + if (globals.size() != 2) { + return; + } + + // Check if the relevant fields contain constants, and are immutable. + auto& wasm = *getModule(); + auto fieldIndex = curr->index; + auto& field = type.getHeapType().getStruct().fields[fieldIndex]; + if (field.mutable_ == Mutable) { + return; + } + auto fieldType = field.type; + std::vector<Literal> values; + for (Index i = 0; i < globals.size(); i++) { + auto* structNew = wasm.getGlobal(globals[i])->init->cast<StructNew>(); + if (structNew->isWithDefault()) { + values.push_back(Literal::makeZero(fieldType)); + } else { + auto* init = structNew->operands[fieldIndex]; + if (!Properties::isConstantExpression(init)) { + // Non-constant; give up entirely. + return; + } + values.push_back(Properties::getLiteral(init)); + } + } + + // Excellent, we can optimize here! Emit a select. + // + // Note that we must trap on null, so add a ref.as_non_null here. + Builder builder(wasm); + replaceCurrent(builder.makeSelect( + builder.makeRefEq(builder.makeRefAs(RefAsNonNull, curr->ref), + builder.makeGlobalGet( + globals[0], wasm.getGlobal(globals[0])->type)), + builder.makeConstantExpression(values[0]), + builder.makeConstantExpression(values[1]))); + } + + private: + GlobalStructInference& parent; + }; + + FunctionOptimizer(*this).run(runner, module); + } +}; + +} // anonymous namespace + +Pass* createGlobalStructInferencePass() { return new GlobalStructInference(); } + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 04947ae07..85a240dc3 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -170,6 +170,8 @@ void PassRegistry::registerPasses() { "global-refining", "refine the types of globals", createGlobalRefiningPass); registerPass( "gto", "globally optimize GC types", createGlobalTypeOptimizationPass); + registerPass( + "gsi", "globally optimize struct values", createGlobalStructInferencePass); registerPass("type-refining", "apply more specific subtypes to type fields where possible", createTypeRefiningPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index d7a6f9989..2c73ed91e 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -52,6 +52,7 @@ Pass* createGenerateDynCallsPass(); Pass* createGenerateI64DynCallsPass(); Pass* createGenerateStackIRPass(); Pass* createGlobalRefiningPass(); +Pass* createGlobalStructInferencePass(); Pass* createGlobalTypeOptimizationPass(); Pass* createHeap2LocalPass(); Pass* createI64ToI32LoweringPass(); |