/* * Copyright 2016 WebAssembly Community Group participants * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // // Optimize combinations of instructions // #include #include #include #include #include #include #include namespace wasm { Name I32_EXPR = "i32.expr", I64_EXPR = "i64.expr", F32_EXPR = "f32.expr", F64_EXPR = "f64.expr", ANY_EXPR = "any.expr"; // A pattern struct Pattern { Expression* input; Expression* output; Pattern(Expression* input, Expression* output) : input(input), output(output) {} }; // Database of patterns struct PatternDatabase { Module wasm; char* input; std::map> patternMap; // root expression id => list of all patterns for it TODO optimize more PatternDatabase() { // generate module input = strdup( #include "OptimizeInstructions.wast.processed" ); try { SExpressionParser parser(input); Element& root = *parser.root; SExpressionWasmBuilder builder(wasm, *root[0]); // parse module form auto* func = wasm.getFunction("patterns"); auto* body = func->body->cast(); for (auto* item : body->list) { auto* pair = item->cast(); patternMap[pair->list[0]->_id].emplace_back(pair->list[0], pair->list[1]); } } catch (ParseException& p) { p.dump(std::cerr); Fatal() << "error in parsing wasm binary"; } } ~PatternDatabase() { free(input); }; }; static PatternDatabase* database = nullptr; struct DatabaseEnsurer { DatabaseEnsurer() { assert(!database); database = new PatternDatabase; } }; // Check for matches and apply them struct Match { Module& wasm; Pattern& pattern; Match(Module& wasm, Pattern& pattern) : wasm(wasm), pattern(pattern) {} std::vector wildcards; // id in i32.any(id) etc. => the expression it represents in this match // Comparing/checking // Check if we can match to this pattern, updating ourselves with the info if so bool check(Expression* seen) { // compare seen to the pattern input, doing a special operation for our "wildcards" assert(wildcards.size() == 0); return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, *this); } bool compare(Expression* subInput, Expression* subSeen) { CallImport* call = subInput->dynCast(); if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is()) return false; Index index = call->operands[0]->cast()->value.geti32(); // handle our special functions auto checkMatch = [&](WasmType type) { if (type != none && subSeen->type != type) return false; while (index >= wildcards.size()) { wildcards.push_back(nullptr); } if (!wildcards[index]) { // new wildcard wildcards[index] = subSeen; // NB: no need to copy return true; } else { // We are seeing this index for a second or later time, check it matches return ExpressionAnalyzer::equal(subSeen, wildcards[index]); }; }; if (call->target == I32_EXPR) { if (checkMatch(i32)) return true; } else if (call->target == I64_EXPR) { if (checkMatch(i64)) return true; } else if (call->target == F32_EXPR) { if (checkMatch(f32)) return true; } else if (call->target == F64_EXPR) { if (checkMatch(f64)) return true; } else if (call->target == ANY_EXPR) { if (checkMatch(none)) return true; } return false; } // Applying/copying // Apply the match, generate an output expression from the matched input, performing substitutions as necessary Expression* apply() { return ExpressionManipulator::flexibleCopy(pattern.output, wasm, *this); } // When copying a wildcard, perform the substitution. // TODO: we can reuse nodes, not copying a wildcard when it appears just once, and we can reuse other individual nodes when they are discarded anyhow. Expression* copy(Expression* curr) { CallImport* call = curr->dynCast(); if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is()) return nullptr; Index index = call->operands[0]->cast()->value.geti32(); // handle our special functions if (call->target == I32_EXPR || call->target == I64_EXPR || call->target == F32_EXPR || call->target == F64_EXPR || call->target == ANY_EXPR) { return ExpressionManipulator::copy(wildcards.at(index), wasm); } return nullptr; } }; // Main pass class struct OptimizeInstructions : public WalkerPass>> { bool isFunctionParallel() override { return true; } Pass* create() override { return new OptimizeInstructions; } void prepareToRun(PassRunner* runner, Module* module) override { static DatabaseEnsurer ensurer; } void visitExpression(Expression* curr) { // we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles while (1) { auto* handOptimized = handOptimize(curr); if (handOptimized) { curr = handOptimized; replaceCurrent(curr); continue; } auto iter = database->patternMap.find(curr->_id); if (iter == database->patternMap.end()) return; auto& patterns = iter->second; bool more = false; for (auto& pattern : patterns) { Match match(*getModule(), pattern); if (match.check(curr)) { curr = match.apply(); replaceCurrent(curr); more = true; break; // exit pattern for loop, return to main while loop } } if (!more) break; } } // Optimizations that don't yet fit in the pattern DSL, but could be eventually maybe Expression* handOptimize(Expression* curr) { if (auto* binary = curr->dynCast()) { if (Properties::isSymmetric(binary)) { // canonicalize a const to the second position if (binary->left->is() && !binary->right->is()) { std::swap(binary->left, binary->right); } } // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. if (binary->op == BinaryOp::ShrSInt32 && binary->right->is()) { auto shifts = binary->right->cast()->value.geti32(); if (shifts == 24 || shifts == 16) { auto* left = binary->left->dynCast(); if (left && left->op == ShlInt32 && left->right->is() && left->right->cast()->value.geti32() == shifts) { auto* load = left->left->dynCast(); if (load && ((load->bytes == 1 && shifts == 24) || (load->bytes == 2 && shifts == 16))) { load->signed_ = true; return load; } } } } else if (binary->op == EqInt32) { if (auto* c = binary->right->dynCast()) { if (c->value.geti32() == 0) { // equal 0 => eqz return Builder(*getModule()).makeUnary(EqZInt32, binary->left); } } if (auto* c = binary->left->dynCast()) { if (c->value.geti32() == 0) { // equal 0 => eqz return Builder(*getModule()).makeUnary(EqZInt32, binary->right); } } } else if (binary->op == AndInt32) { if (auto* right = binary->right->dynCast()) { if (right->type == i32) { auto mask = right->value.geti32(); // and with -1 does nothing (common in asm.js output) if (mask == -1) { return binary->left; } // small loads do not need to be masted, the load itself masks if (auto* load = binary->left->dynCast()) { if ((load->bytes == 1 && mask == 0xff) || (load->bytes == 2 && mask == 0xffff)) { load->signed_ = false; return load; } } else if (mask == 1 && Properties::emitsBoolean(binary->left)) { // (bool) & 1 does not need the outer mask return binary->left; } } } } } else if (auto* unary = curr->dynCast()) { // de-morgan's laws if (unary->op == EqZInt32) { if (auto* inner = unary->value->dynCast()) { switch (inner->op) { case EqInt32: inner->op = NeInt32; return inner; case NeInt32: inner->op = EqInt32; return inner; case LtSInt32: inner->op = GeSInt32; return inner; case LtUInt32: inner->op = GeUInt32; return inner; case LeSInt32: inner->op = GtSInt32; return inner; case LeUInt32: inner->op = GtUInt32; return inner; case GtSInt32: inner->op = LeSInt32; return inner; case GtUInt32: inner->op = LeUInt32; return inner; case GeSInt32: inner->op = LtSInt32; return inner; case GeUInt32: inner->op = LtUInt32; return inner; case EqInt64: inner->op = NeInt64; return inner; case NeInt64: inner->op = EqInt64; return inner; case LtSInt64: inner->op = GeSInt64; return inner; case LtUInt64: inner->op = GeUInt64; return inner; case LeSInt64: inner->op = GtSInt64; return inner; case LeUInt64: inner->op = GtUInt64; return inner; case GtSInt64: inner->op = LeSInt64; return inner; case GtUInt64: inner->op = LeUInt64; return inner; case GeSInt64: inner->op = LtSInt64; return inner; case GeUInt64: inner->op = LtUInt64; return inner; case EqFloat32: inner->op = NeFloat32; return inner; case NeFloat32: inner->op = EqFloat32; return inner; case EqFloat64: inner->op = NeFloat64; return inner; case NeFloat64: inner->op = EqFloat64; return inner; default: {} } } } } else if (auto* set = curr->dynCast()) { // optimize out a set of a get auto* get = set->value->dynCast(); if (get && get->name == set->name) { ExpressionManipulator::nop(curr); } } else if (auto* iff = curr->dynCast()) { iff->condition = optimizeBoolean(iff->condition); if (iff->ifFalse) { if (auto* unary = iff->condition->dynCast()) { if (unary->op == EqZInt32) { // flip if-else arms to get rid of an eqz iff->condition = unary->value; std::swap(iff->ifTrue, iff->ifFalse); } } } } else if (auto* select = curr->dynCast