summaryrefslogtreecommitdiff
path: root/src/passes/OptimizeInstructions.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r--src/passes/OptimizeInstructions.cpp186
1 files changed, 136 insertions, 50 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index efa9409a7..9fa059327 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -22,66 +22,152 @@
#include <wasm.h>
#include <pass.h>
+#include <wasm-s-parser.h>
namespace wasm {
-struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, Visitor<OptimizeInstructions>>> {
- bool isFunctionParallel() override { return true; }
+Name I32_EXPR = "i32.expr",
+ I64_EXPR = "i64.expr",
+ F32_EXPR = "f32.expr",
+ F64_EXPR = "f64.expr",
+ ANY_EXPR = "any.expr";
- Pass* create() override { return new OptimizeInstructions; }
+// A pattern
+struct Pattern {
+ Expression* input;
+ Expression* output;
+
+ Pattern(Expression* input, Expression* output) : input(input), output(output) {}
+};
+
+// Database of patterns
+struct PatternDatabase {
+ Module wasm;
+
+ char* input;
+
+ std::map<Expression::Id, std::vector<Pattern>> patternMap; // root expression id => list of all patterns for it TODO optimize more
+
+ PatternDatabase() {
+ // TODO: do this on first use, with a lock, to avoid startup pause
+ // generate module
+ input = strdup(
+ #include "OptimizeInstructions.wast.processed"
+ );
+ SExpressionParser parser(input);
+ Element& root = *parser.root;
+ SExpressionWasmBuilder builder(wasm, *root[0]);
+ // parse module form
+ auto* func = wasm.getFunction("patterns");
+ auto* body = func->body->cast<Block>();
+ for (auto* item : body->list) {
+ auto* pair = item->cast<Block>();
+ patternMap[pair->list[0]->_id].emplace_back(pair->list[0], pair->list[1]);
+ }
+ }
+
+ ~PatternDatabase() {
+ free(input);
+ };
+};
+
+static PatternDatabase database;
+
+// Check for matches and apply them
+struct Match {
+ Module& wasm;
+ Pattern& pattern;
+
+ Match(Module& wasm, Pattern& pattern) : wasm(wasm), pattern(pattern) {}
+
+ std::vector<Expression*> wildcards; // id in i32.any(id) etc. => the expression it represents in this match
- void visitIf(If* curr) {
- // flip branches to get rid of an i32.eqz
- if (curr->ifFalse) {
- auto condition = curr->condition->dynCast<Unary>();
- if (condition && condition->op == EqZInt32 && condition->value->type == i32) {
- curr->condition = condition->value;
- std::swap(curr->ifTrue, curr->ifFalse);
+ // Comparing/checking
+
+ // Check if we can match to this pattern, updating ourselves with the info if so
+ bool check(Expression* seen) {
+ // compare seen to the pattern input, doing a special operation for our "wildcards"
+ assert(wildcards.size() == 0);
+ return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, *this);
+ }
+
+ bool compare(Expression* subInput, Expression* subSeen) {
+ CallImport* call = subInput->dynCast<CallImport>();
+ if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return false;
+ Index index = call->operands[0]->cast<Const>()->value.geti32();
+ // handle our special functions
+ auto checkMatch = [&](WasmType type) {
+ if (type != none && subSeen->type != type) return false;
+ while (index >= wildcards.size()) {
+ wildcards.push_back(nullptr);
}
+ if (!wildcards[index]) {
+ // new wildcard
+ wildcards[index] = subSeen; // NB: no need to copy
+ return true;
+ } else {
+ // We are seeing this index for a second or later time, check it matches
+ return ExpressionAnalyzer::equal(subSeen, wildcards[index]);
+ };
+ };
+ if (call->target == I32_EXPR) {
+ if (checkMatch(i32)) return true;
+ } else if (call->target == I64_EXPR) {
+ if (checkMatch(i64)) return true;
+ } else if (call->target == F32_EXPR) {
+ if (checkMatch(f32)) return true;
+ } else if (call->target == F64_EXPR) {
+ if (checkMatch(f64)) return true;
+ } else if (call->target == ANY_EXPR) {
+ if (checkMatch(none)) return true;
}
+ return false;
+ }
+
+ // Applying/copying
+
+ // Apply the match, generate an output expression from the matched input, performing substitutions as necessary
+ Expression* apply() {
+ return ExpressionManipulator::flexibleCopy(pattern.output, wasm, *this);
}
- void visitUnary(Unary* curr) {
- if (curr->op == EqZInt32) {
- // fold comparisons that flow into an EqZ
- auto* child = curr->value->dynCast<Binary>();
- if (child && (child->type == i32 || child->type == i64)) {
- switch (child->op) {
- case EqInt32: child->op = NeInt32; break;
- case NeInt32: child->op = EqInt32; break;
- case LtSInt32: child->op = GeSInt32; break;
- case LtUInt32: child->op = GeUInt32; break;
- case LeSInt32: child->op = GtSInt32; break;
- case LeUInt32: child->op = GtUInt32; break;
- case GtSInt32: child->op = LeSInt32; break;
- case GtUInt32: child->op = LeUInt32; break;
- case GeSInt32: child->op = LtSInt32; break;
- case GeUInt32: child->op = LtUInt32; break;
- case EqInt64: child->op = NeInt64; break;
- case NeInt64: child->op = EqInt64; break;
- case LtSInt64: child->op = GeSInt64; break;
- case LtUInt64: child->op = GeUInt64; break;
- case LeSInt64: child->op = GtSInt64; break;
- case LeUInt64: child->op = GtUInt64; break;
- case GtSInt64: child->op = LeSInt64; break;
- case GtUInt64: child->op = LeUInt64; break;
- case GeSInt64: child->op = LtSInt64; break;
- case GeUInt64: child->op = LtUInt64; break;
- case EqFloat32: child->op = NeFloat32; break;
- case NeFloat32: child->op = EqFloat32; break;
- case LtFloat32: child->op = GeFloat32; break;
- case LeFloat32: child->op = GtFloat32; break;
- case GtFloat32: child->op = LeFloat32; break;
- case GeFloat32: child->op = LtFloat32; break;
- case EqFloat64: child->op = NeFloat64; break;
- case NeFloat64: child->op = EqFloat64; break;
- case LtFloat64: child->op = GeFloat64; break;
- case LeFloat64: child->op = GtFloat64; break;
- case GtFloat64: child->op = LeFloat64; break;
- case GeFloat64: child->op = LtFloat64; break;
- default: return;
+
+ // When copying a wildcard, perform the substitution.
+ // TODO: we can reuse nodes, not copying a wildcard when it appears just once, and we can reuse other individual nodes when they are discarded anyhow.
+ Expression* copy(Expression* curr) {
+ CallImport* call = curr->dynCast<CallImport>();
+ if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return nullptr;
+ Index index = call->operands[0]->cast<Const>()->value.geti32();
+ // handle our special functions
+ if (call->target == I32_EXPR || call->target == I64_EXPR || call->target == F32_EXPR || call->target == F64_EXPR || call->target == ANY_EXPR) {
+ return ExpressionManipulator::copy(wildcards.at(index), wasm);
+ }
+ return nullptr;
+ }
+};
+
+// Main pass class
+struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new OptimizeInstructions; }
+
+ void visitExpression(Expression* curr) {
+ // we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles
+ while (1) {
+ auto iter = database.patternMap.find(curr->_id);
+ if (iter == database.patternMap.end()) return;
+ auto& patterns = iter->second;
+ bool more = false;
+ for (auto& pattern : patterns) {
+ Match match(*getModule(), pattern);
+ if (match.check(curr)) {
+ curr = match.apply();
+ replaceCurrent(curr);
+ more = true;
+ break; // exit pattern for loop, return to main while loop
}
- replaceCurrent(child);
}
+ if (!more) break;
}
}
};