summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/passes/CMakeLists.txt1
-rw-r--r--src/passes/ConstantFieldPropagation.cpp456
-rw-r--r--src/passes/pass.cpp7
-rw-r--r--src/passes/passes.h1
-rw-r--r--src/wasm-type.h2
-rw-r--r--src/wasm/wasm-type.cpp3
6 files changed, 470 insertions, 0 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index 9cc5e1f96..ad5600f1b 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -15,6 +15,7 @@ set(passes_SOURCES
CoalesceLocals.cpp
CodePushing.cpp
CodeFolding.cpp
+ ConstantFieldPropagation.cpp
ConstHoisting.cpp
DataFlowOpts.cpp
DeadArgumentElimination.cpp
diff --git a/src/passes/ConstantFieldPropagation.cpp b/src/passes/ConstantFieldPropagation.cpp
new file mode 100644
index 000000000..621ceec2a
--- /dev/null
+++ b/src/passes/ConstantFieldPropagation.cpp
@@ -0,0 +1,456 @@
+/*
+ * Copyright 2021 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Find struct fields that are always written to with a constant value, and
+// replace gets of them with that value.
+//
+// For example, if we have a vtable of type T, and we always create it with one
+// of the fields containing a ref.func of the same function F, and there is no
+// write to that field of a different value (even using a subtype of T), then
+// anywhere we see a get of that field we can place a ref.func of F.
+//
+// FIXME: This pass assumes a closed world. When we start to allow multi-module
+// wasm GC programs we need to check for type escaping.
+//
+
+#include "ir/module-utils.h"
+#include "ir/properties.h"
+#include "ir/utils.h"
+#include "pass.h"
+#include "support/unique_deferring_queue.h"
+#include "wasm-builder.h"
+#include "wasm-traversal.h"
+#include "wasm.h"
+
+namespace wasm {
+
+namespace {
+
+// A nominal type always knows who its supertype is, if there is one; this class
+// provides the list of immediate subtypes.
+struct SubTypes {
+ SubTypes(Module& wasm) {
+ std::vector<HeapType> types;
+ std::unordered_map<HeapType, Index> typeIndices;
+ ModuleUtils::collectHeapTypes(wasm, types, typeIndices);
+ for (auto type : types) {
+ note(type);
+ }
+ }
+
+ const std::unordered_set<HeapType>& getSubTypes(HeapType type) {
+ return typeSubTypes[type];
+ }
+
+private:
+ // Add a type to the graph.
+ void note(HeapType type) {
+ HeapType super;
+ if (type.getSuperType(super)) {
+ typeSubTypes[super].insert(type);
+ }
+ }
+
+ // Maps a type to its subtypes.
+ std::unordered_map<HeapType, std::unordered_set<HeapType>> typeSubTypes;
+};
+
+// Represents data about what constant values are possible in a particular
+// place. There may be no values, or one, or many, or if a non-constant value is
+// possible, then all we can say is that the value is "unknown" - it can be
+// anything.
+//
+// Currently this just looks for a single constant value, and even two constant
+// values are treated as unknown. It may be worth optimizing more than that TODO
+struct PossibleConstantValues {
+ // Note a written value as we see it, and update our internal knowledge based
+ // on it and all previous values noted.
+ void note(Literal curr) {
+ if (!noted) {
+ // This is the first value.
+ value = curr;
+ noted = true;
+ return;
+ }
+
+ // This is a subsequent value. Check if it is different from all previous
+ // ones.
+ if (curr != value) {
+ noteUnknown();
+ }
+ }
+
+ // Notes a value that is unknown - it can be anything. We have failed to
+ // identify a constant value here.
+ void noteUnknown() {
+ value = Literal(Type::none);
+ noted = true;
+ }
+
+ // Combine the information in a given PossibleConstantValues to this one. This
+ // is the same as if we have called note*() on us with all the history of
+ // calls to that other object.
+ //
+ // Returns whether we changed anything.
+ bool combine(const PossibleConstantValues& other) {
+ if (!other.noted) {
+ return false;
+ }
+ if (!noted) {
+ *this = other;
+ return other.noted;
+ }
+ if (!isConstant()) {
+ return false;
+ }
+ if (!other.isConstant() || getConstantValue() != other.getConstantValue()) {
+ noteUnknown();
+ return true;
+ }
+ return false;
+ }
+
+ // Check if all the values are identical and constant.
+ bool isConstant() const { return noted && value.type.isConcrete(); }
+
+ // Returns the single constant value.
+ Literal getConstantValue() const {
+ assert(isConstant());
+ return value;
+ }
+
+ // Returns whether we have ever noted a value.
+ bool hasNoted() const { return noted; }
+
+ void dump(std::ostream& o) {
+ o << '[';
+ if (!hasNoted()) {
+ o << "unwritten";
+ } else if (!isConstant()) {
+ o << "unknown";
+ } else {
+ o << value;
+ }
+ o << ']';
+ }
+
+private:
+ // Whether we have noted any values at all.
+ bool noted = false;
+
+ // The one value we have seen, if there is one. If we realize there is no
+ // single constant value here, we make this have a non-concrete (impossible)
+ // type to indicate that. Otherwise, a concrete type indicates we have a
+ // constant value.
+ Literal value;
+};
+
+// A vector of PossibleConstantValues. One such vector will be used per struct
+// type, where each element in the vector represents a field. We always assume
+// that the vectors are pre-initialized to the right length before accessing any
+// data, which this class enforces using assertions, and which is implemented in
+// StructValuesMap.
+struct StructValues : public std::vector<PossibleConstantValues> {
+ PossibleConstantValues& operator[](size_t index) {
+ assert(index < size());
+ return std::vector<PossibleConstantValues>::operator[](index);
+ }
+};
+
+// Map of types to information about the values their fields can take.
+// Concretely, this maps a type to a StructValues which has one element per
+// field.
+struct StructValuesMap : public std::unordered_map<HeapType, StructValues> {
+ // When we access an item, if it does not already exist, create it with a
+ // vector of the right length for that type.
+ StructValues& operator[](HeapType type) {
+ auto inserted = insert({type, {}});
+ auto& values = inserted.first->second;
+ if (inserted.second) {
+ values.resize(type.getStruct().fields.size());
+ }
+ return values;
+ }
+
+ void dump(std::ostream& o) {
+ o << "dump " << this << '\n';
+ for (auto& kv : (*this)) {
+ auto type = kv.first;
+ auto& vec = kv.second;
+ o << "dump " << type << " " << &vec << ' ';
+ for (auto x : vec) {
+ x.dump(o);
+ o << " ";
+ };
+ o << '\n';
+ }
+ }
+};
+
// Map of functions to their field value infos. We compute those in parallel
// (the entries are pre-created so that each function's scan only mutates its
// own entry), then later we will merge them all.
using FunctionStructValuesMap = std::unordered_map<Function*, StructValuesMap>;
+
+// Scan each function to note all its writes to struct fields.
+struct Scanner : public WalkerPass<PostWalker<Scanner>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new Scanner(functionInfos); }
+
+ Scanner(FunctionStructValuesMap& functionInfos)
+ : functionInfos(functionInfos) {}
+
+ void visitStructNew(StructNew* curr) {
+ auto type = curr->type;
+ if (type == Type::unreachable) {
+ return;
+ }
+
+ // Note writes to all the fields of the struct.
+ auto heapType = type.getHeapType();
+ auto& values = getStructValues(heapType);
+ auto& fields = heapType.getStruct().fields;
+ for (Index i = 0; i < fields.size(); i++) {
+ auto& fieldValues = values[i];
+ if (curr->isWithDefault()) {
+ fieldValues.note(Literal::makeZero(fields[i].type));
+ } else {
+ noteExpression(curr->operands[i], fieldValues);
+ }
+ }
+ }
+
+ void visitStructSet(StructSet* curr) {
+ auto type = curr->ref->type;
+ if (type == Type::unreachable) {
+ return;
+ }
+
+ // Note a write to this field of the struct.
+ auto heapType = type.getHeapType();
+ noteExpression(curr->value, getStructValues(heapType)[curr->index]);
+ }
+
+private:
+ FunctionStructValuesMap& functionInfos;
+
+ StructValues& getStructValues(HeapType type) {
+ return functionInfos[getFunction()][type];
+ }
+
+ // Note a value, checking whether it is a constant or not.
+ void noteExpression(Expression* expr, PossibleConstantValues& info) {
+ expr =
+ Properties::getFallthrough(expr, getPassOptions(), getModule()->features);
+ if (!Properties::isConstantExpression(expr)) {
+ info.noteUnknown();
+ } else {
+ info.note(Properties::getLiteral(expr));
+ }
+ }
+};
+
+// Optimize struct gets based on what we've learned about writes.
+//
+// TODO Aside from writes, we could use information like whether any struct of
+// this type has even been created (to handle the case of struct.sets but
+// no struct.news).
+struct FunctionOptimizer : public WalkerPass<PostWalker<FunctionOptimizer>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new FunctionOptimizer(infos); }
+
+ FunctionOptimizer(StructValuesMap& infos) : infos(infos) {}
+
+ void visitStructGet(StructGet* curr) {
+ auto type = curr->ref->type;
+ if (type == Type::unreachable) {
+ return;
+ }
+
+ Builder builder(*getModule());
+
+ // Find the info for this field, and see if we can optimize. First, see if
+ // there is any information for this heap type at all. If there isn't, it is
+ // as if nothing was ever noted for that field.
+ PossibleConstantValues info;
+ assert(!info.hasNoted());
+ auto iter = infos.find(type.getHeapType());
+ if (iter != infos.end()) {
+ // There is information on this type, fetch it.
+ info = iter->second[curr->index];
+ }
+
+ if (!info.hasNoted()) {
+ // This field is never written at all. That means that we do not even
+ // construct any data of this type, and so it is a logic error to reach
+ // this location in the code. (Unless we are in an open-world
+ // situation, which we assume we are not in.) Replace this get with a
+ // trap. Note that we do not need to care about the nullability of the
+ // reference, as if it should have trapped, we are replacing it with
+ // another trap, which we allow to reorder (but we do need to care about
+ // side effects in the reference, so keep it around).
+ replaceCurrent(builder.makeSequence(builder.makeDrop(curr->ref),
+ builder.makeUnreachable()));
+ changed = true;
+ return;
+ }
+
+ // If the value is not a constant, then it is unknown and we must give up.
+ if (!info.isConstant()) {
+ return;
+ }
+
+ // We can do this! Replace the get with a trap on a null reference using a
+ // ref.as_non_null (we need to trap as the get would have done so), plus the
+ // constant value. (Leave it to further optimizations to get rid of the
+ // ref.)
+ replaceCurrent(builder.makeSequence(
+ builder.makeDrop(builder.makeRefAs(RefAsNonNull, curr->ref)),
+ builder.makeConstantExpression(info.getConstantValue())));
+ changed = true;
+ }
+
+ void doWalkFunction(Function* func) {
+ WalkerPass<PostWalker<FunctionOptimizer>>::doWalkFunction(func);
+
+ // If we changed anything, we need to update parent types as types may have
+ // changed.
+ if (changed) {
+ ReFinalize().walkFunctionInModule(func, getModule());
+ }
+ }
+
+private:
+ StructValuesMap& infos;
+
+ bool changed = false;
+};
+
// The pass driver: scan all writes, merge and propagate the results through
// the type hierarchy, then optimize the gets.
struct ConstantFieldPropagation : public Pass {
  void run(PassRunner* runner, Module* module) override {
    // Supertypes are only known ahead of time under nominal typing, which the
    // propagation below relies on.
    if (getTypeSystem() != TypeSystem::Nominal) {
      Fatal() << "ConstantFieldPropagation requires nominal typing";
    }

    // Find and analyze all writes inside each function.
    FunctionStructValuesMap functionInfos;
    for (auto& func : module->functions) {
      // Initialize the data for each function, so that we can operate on this
      // structure in parallel without modifying it.
      functionInfos[func.get()];
    }
    Scanner scanner(functionInfos);
    scanner.run(runner, module);
    // Also note writes that occur in module-level code, outside any function.
    scanner.walkModuleCode(module);

    // Combine the data from the functions.
    // NOTE(review): the inner |kv| shadows the outer one; renaming one of them
    // would aid readability.
    StructValuesMap combinedInfos;
    for (auto& kv : functionInfos) {
      StructValuesMap& infos = kv.second;
      for (auto& kv : infos) {
        auto type = kv.first;
        auto& info = kv.second;
        for (Index i = 0; i < info.size(); i++) {
          combinedInfos[type][i].combine(info[i]);
        }
      }
    }

    // Handle subtyping. |combinedInfos| so far contains data that represents
    // each struct.new and struct.set's operation on the struct type used in
    // that instruction. That is, if we do a struct.set to type T, the value was
    // noted for type T. But our actual goal is to answer questions about
    // struct.gets. Specifically, when later we see:
    //
    //  (struct.get $A x (REF-1))
    //
    // Then we want to be aware of all the relevant struct.sets, that is, the
    // sets that can write data that this get reads. Given a set
    //
    //  (struct.set $B x (REF-2) (..value..))
    //
    // then
    //
    //  1. If $B is a subtype of $A, it is relevant: the get might read from a
    //     struct of type $B (i.e., REF-1 and REF-2 might be identical, and both
    //     be a struct of type $B).
    //  2. If $B is a supertype of $A that still has the field x then it may
    //     also be relevant: since $A is a subtype of $B, the set may write to a
    //     struct of type $A (and again, REF-1 and REF-2 may be identical).
    //
    // Thus, if either $A <: $B or $B <: $A then we must consider the get and
    // set to be relevant to each other. To make our later lookups for gets
    // efficient, we therefore propagate information about the possible values
    // in each field to both subtypes and supertypes.
    //
    // TODO: Model struct.new separately from struct.set. With new we actually
    //       do know the specific type being written to, which means a get is
    //       only relevant for a new if the get is of a subtype. That means we
    //       only need to propagate values from new to subtypes.
    //
    // TODO: A topological sort could avoid repeated work here perhaps.
    SubTypes subTypes(*module);
    UniqueDeferredQueue<HeapType> work;
    for (auto& kv : combinedInfos) {
      auto type = kv.first;
      work.push(type);
    }
    // Process the worklist to a fixed point: whenever a type's information
    // changes, that type is re-queued so the change can flow onward from it.
    while (!work.empty()) {
      auto type = work.pop();
      // (References into an unordered_map stay valid across the insertions
      // that the operator[] calls below may perform.)
      auto& infos = combinedInfos[type];

      // Propagate shared fields to the supertype. Only the supertype's own
      // field count is iterated, as the subtype may have additional fields.
      HeapType superType;
      if (type.getSuperType(superType)) {
        auto& superInfos = combinedInfos[superType];
        auto& superFields = superType.getStruct().fields;
        for (Index i = 0; i < superFields.size(); i++) {
          if (superInfos[i].combine(infos[i])) {
            work.push(superType);
          }
        }
      }

      // Propagate shared fields to the subtypes.
      auto numFields = type.getStruct().fields.size();
      for (auto subType : subTypes.getSubTypes(type)) {
        auto& subInfos = combinedInfos[subType];
        for (Index i = 0; i < numFields; i++) {
          if (subInfos[i].combine(infos[i])) {
            work.push(subType);
          }
        }
      }
    }

    // Optimize.
    // TODO: Skip this if we cannot optimize anything
    FunctionOptimizer(combinedInfos).run(runner, module);

    // TODO: Actually remove the field from the type, where possible? That might
    //       be best in another pass.
  }
};
+
+} // anonymous namespace
+
// Factory used by the pass registry (registered as "cfp" in pass.cpp).
Pass* createConstantFieldPropagationPass() {
  return new ConstantFieldPropagation();
}
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 137a89703..fa706dd10 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -102,6 +102,9 @@ void PassRegistry::registerPasses() {
registerPass("const-hoisting",
"hoist repeated constants to a local",
createConstHoistingPass);
+ registerPass("cfp",
+ "propagate constant struct field values",
+ createConstantFieldPropagationPass);
registerPass(
"dce", "removes unreachable code", createDeadCodeEliminationPass);
registerPass("dealign",
@@ -499,6 +502,10 @@ void PassRunner::addDefaultFunctionOptimizationPasses() {
void PassRunner::addDefaultGlobalOptimizationPrePasses() {
addIfNoDWARFIssues("duplicate-function-elimination");
addIfNoDWARFIssues("memory-packing");
+ if (wasm->features.hasGC() && getTypeSystem() == TypeSystem::Nominal &&
+ options.optimizeLevel >= 2) {
+ addIfNoDWARFIssues("cfp");
+ }
}
void PassRunner::addDefaultGlobalOptimizationPostPasses() {
diff --git a/src/passes/passes.h b/src/passes/passes.h
index c58259f86..9a9d6378d 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -30,6 +30,7 @@ Pass* createCoalesceLocalsWithLearningPass();
Pass* createCodeFoldingPass();
Pass* createCodePushingPass();
Pass* createConstHoistingPass();
+Pass* createConstantFieldPropagationPass();
Pass* createDAEPass();
Pass* createDAEOptimizingPass();
Pass* createDataFlowOptsPass();
diff --git a/src/wasm-type.h b/src/wasm-type.h
index ff1019e55..43a93b4aa 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -41,6 +41,8 @@ enum class TypeSystem {
// created. The default system is equirecursive.
void setTypeSystem(TypeSystem system);
+TypeSystem getTypeSystem();
+
// The types defined in this file. All of them are small and typically passed by
// value except for `Tuple` and `Struct`, which may own an unbounded amount of
// data.
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index 998d1a96b..34d7e942f 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -43,8 +43,11 @@
namespace wasm {
static TypeSystem typeSystem = TypeSystem::Equirecursive;
+
void setTypeSystem(TypeSystem system) { typeSystem = system; }
+TypeSystem getTypeSystem() { return typeSystem; }
+
namespace {
struct TypeInfo {