diff options
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 186 |
1 files changed, 136 insertions, 50 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index efa9409a7..9fa059327 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -22,66 +22,152 @@ #include <wasm.h> #include <pass.h> +#include <wasm-s-parser.h> namespace wasm { -struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, Visitor<OptimizeInstructions>>> { - bool isFunctionParallel() override { return true; } +Name I32_EXPR = "i32.expr", + I64_EXPR = "i64.expr", + F32_EXPR = "f32.expr", + F64_EXPR = "f64.expr", + ANY_EXPR = "any.expr"; - Pass* create() override { return new OptimizeInstructions; } +// A pattern +struct Pattern { + Expression* input; + Expression* output; + + Pattern(Expression* input, Expression* output) : input(input), output(output) {} +}; + +// Database of patterns +struct PatternDatabase { + Module wasm; + + char* input; + + std::map<Expression::Id, std::vector<Pattern>> patternMap; // root expression id => list of all patterns for it TODO optimize more + + PatternDatabase() { + // TODO: do this on first use, with a lock, to avoid startup pause + // generate module + input = strdup( + #include "OptimizeInstructions.wast.processed" + ); + SExpressionParser parser(input); + Element& root = *parser.root; + SExpressionWasmBuilder builder(wasm, *root[0]); + // parse module form + auto* func = wasm.getFunction("patterns"); + auto* body = func->body->cast<Block>(); + for (auto* item : body->list) { + auto* pair = item->cast<Block>(); + patternMap[pair->list[0]->_id].emplace_back(pair->list[0], pair->list[1]); + } + } + + ~PatternDatabase() { + free(input); + }; +}; + +static PatternDatabase database; + +// Check for matches and apply them +struct Match { + Module& wasm; + Pattern& pattern; + + Match(Module& wasm, Pattern& pattern) : wasm(wasm), pattern(pattern) {} + + std::vector<Expression*> wildcards; // id in i32.any(id) etc. => the expression it represents in this match - void visitIf(If* curr) { - // flip branches to get rid of an i32.eqz - if (curr->ifFalse) { - auto condition = curr->condition->dynCast<Unary>(); - if (condition && condition->op == EqZInt32 && condition->value->type == i32) { - curr->condition = condition->value; - std::swap(curr->ifTrue, curr->ifFalse); + // Comparing/checking + + // Check if we can match to this pattern, updating ourselves with the info if so + bool check(Expression* seen) { + // compare seen to the pattern input, doing a special operation for our "wildcards" + assert(wildcards.size() == 0); + return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, *this); + } + + bool compare(Expression* subInput, Expression* subSeen) { + CallImport* call = subInput->dynCast<CallImport>(); + if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return false; + Index index = call->operands[0]->cast<Const>()->value.geti32(); + // handle our special functions + auto checkMatch = [&](WasmType type) { + if (type != none && subSeen->type != type) return false; + while (index >= wildcards.size()) { + wildcards.push_back(nullptr); } + if (!wildcards[index]) { + // new wildcard + wildcards[index] = subSeen; // NB: no need to copy + return true; + } else { + // We are seeing this index for a second or later time, check it matches + return ExpressionAnalyzer::equal(subSeen, wildcards[index]); + }; + }; + if (call->target == I32_EXPR) { + if (checkMatch(i32)) return true; + } else if (call->target == I64_EXPR) { + if (checkMatch(i64)) return true; + } else if (call->target == F32_EXPR) { + if (checkMatch(f32)) return true; + } else if (call->target == F64_EXPR) { + if (checkMatch(f64)) return true; + } else if (call->target == ANY_EXPR) { + if (checkMatch(none)) return true; } + return false; + } + + // Applying/copying + + // Apply the match, generate an output expression from the matched input, performing substitutions as necessary + Expression* apply() { + return ExpressionManipulator::flexibleCopy(pattern.output, wasm, *this); } - void visitUnary(Unary* curr) { - if (curr->op == EqZInt32) { - // fold comparisons that flow into an EqZ - auto* child = curr->value->dynCast<Binary>(); - if (child && (child->type == i32 || child->type == i64)) { - switch (child->op) { - case EqInt32: child->op = NeInt32; break; - case NeInt32: child->op = EqInt32; break; - case LtSInt32: child->op = GeSInt32; break; - case LtUInt32: child->op = GeUInt32; break; - case LeSInt32: child->op = GtSInt32; break; - case LeUInt32: child->op = GtUInt32; break; - case GtSInt32: child->op = LeSInt32; break; - case GtUInt32: child->op = LeUInt32; break; - case GeSInt32: child->op = LtSInt32; break; - case GeUInt32: child->op = LtUInt32; break; - case EqInt64: child->op = NeInt64; break; - case NeInt64: child->op = EqInt64; break; - case LtSInt64: child->op = GeSInt64; break; - case LtUInt64: child->op = GeUInt64; break; - case LeSInt64: child->op = GtSInt64; break; - case LeUInt64: child->op = GtUInt64; break; - case GtSInt64: child->op = LeSInt64; break; - case GtUInt64: child->op = LeUInt64; break; - case GeSInt64: child->op = LtSInt64; break; - case GeUInt64: child->op = LtUInt64; break; - case EqFloat32: child->op = NeFloat32; break; - case NeFloat32: child->op = EqFloat32; break; - case LtFloat32: child->op = GeFloat32; break; - case LeFloat32: child->op = GtFloat32; break; - case GtFloat32: child->op = LeFloat32; break; - case GeFloat32: child->op = LtFloat32; break; - case EqFloat64: child->op = NeFloat64; break; - case NeFloat64: child->op = EqFloat64; break; - case LtFloat64: child->op = GeFloat64; break; - case LeFloat64: child->op = GtFloat64; break; - case GtFloat64: child->op = LeFloat64; break; - case GeFloat64: child->op = LtFloat64; break; - default: return; + + // When copying a wildcard, perform the substitution. + // TODO: we can reuse nodes, not copying a wildcard when it appears just once, and we can reuse other individual nodes when they are discarded anyhow. + Expression* copy(Expression* curr) { + CallImport* call = curr->dynCast<CallImport>(); + if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return nullptr; + Index index = call->operands[0]->cast<Const>()->value.geti32(); + // handle our special functions + if (call->target == I32_EXPR || call->target == I64_EXPR || call->target == F32_EXPR || call->target == F64_EXPR || call->target == ANY_EXPR) { + return ExpressionManipulator::copy(wildcards.at(index), wasm); + } + return nullptr; + } +}; + +// Main pass class +struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new OptimizeInstructions; } + + void visitExpression(Expression* curr) { + // we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles + while (1) { + auto iter = database.patternMap.find(curr->_id); + if (iter == database.patternMap.end()) return; + auto& patterns = iter->second; + bool more = false; + for (auto& pattern : patterns) { + Match match(*getModule(), pattern); + if (match.check(curr)) { + curr = match.apply(); + replaceCurrent(curr); + more = true; + break; // exit pattern for loop, return to main while loop } - replaceCurrent(child); } + if (!more) break; } } }; |