/*
 * Copyright 2016 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Optimize combinations of instructions
//

#include <algorithm>
#include <type_traits>

#include <wasm.h>
#include <pass.h>
#include <wasm-builder.h>
#include <wasm-s-parser.h>
#include <support/bits.h>
#include <ast_utils.h>
#include <ast/bits.h>
#include <ast/effects.h>
#include <ast/literal-utils.h>
#include <ast/manipulation.h>
#include <ast/properties.h>

namespace wasm {

Name I32_EXPR = "i32.expr",
     I64_EXPR = "i64.expr",
     F32_EXPR = "f32.expr",
     F64_EXPR = "f64.expr",
     ANY_EXPR = "any.expr";

// A pattern: an input expression to match, and the output expression to replace it with
struct Pattern {
  Expression* input;
  Expression* output;

  Pattern(Expression* input, Expression* output) : input(input), output(output) {}
};

#if 0
// Database of patterns
struct PatternDatabase {
  Module wasm;

  char* input;

  // root expression id => list of all patterns for it. TODO: optimize more
  std::map<Expression::Id, std::vector<Pattern>> patternMap;

  PatternDatabase() {
    // generate module
    input = strdup(
      #include "OptimizeInstructions.wast.processed"
    );
    try {
      SExpressionParser parser(input);
      Element& root = *parser.root;
      SExpressionWasmBuilder builder(wasm, *root[0]); // parse module form
      // the "patterns" function's body is a block of (input, output) pairs
      auto* func = wasm.getFunction("patterns");
      auto* body = func->body->cast<Block>();
      for (auto* item : body->list) {
        auto* pair = item->cast<Block>();
        patternMap[pair->list[0]->_id].emplace_back(pair->list[0], pair->list[1]);
      }
    } catch (ParseException& p) {
      p.dump(std::cerr);
      Fatal() << "error in parsing pattern module";
    }
  }

  ~PatternDatabase() {
    free(input);
  }
};

static PatternDatabase* database = nullptr;

struct DatabaseEnsurer {
  DatabaseEnsurer() {
    assert(!database);
    database = new PatternDatabase;
  }
};
#endif

// Check for matches and apply them
struct Match {
  Module& wasm;
  Pattern& pattern;

  Match(Module& wasm, Pattern& pattern) : wasm(wasm), pattern(pattern) {}

  std::vector<Expression*> wildcards; // id in i32.expr(id) etc. => the expression it represents in this match
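
  // Illustratively, a pattern whose input is (i32.expr(0) & -1) and whose output is
  // i32.expr(0) would rewrite x & -1 into x, with wildcard 0 bound to whatever
  // expression the first i32.expr(0) matched against.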

  // Comparing/checking

  // Check if we can match to this pattern, updating ourselves with the info if so
  bool check(Expression* seen) {
    // compare seen to the pattern input, doing a special operation for our "wildcards"
    assert(wildcards.size() == 0);
    auto compare = [this](Expression* subInput, Expression* subSeen) {
      CallImport* call = subInput->dynCast<CallImport>();
      if (!call || call->operands.size() != 1 ||
          call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) {
        return false;
      }
      Index index = call->operands[0]->cast<Const>()->value.geti32();
      // handle our special functions
      auto checkMatch = [&](WasmType type) {
        if (type != none && subSeen->type != type) return false;
        while (index >= wildcards.size()) {
          wildcards.push_back(nullptr);
        }
        if (!wildcards[index]) {
          // new wildcard
          wildcards[index] = subSeen; // NB: no need to copy
          return true;
        } else {
          // We are seeing this index for a second or later time, check it matches
          return ExpressionAnalyzer::equal(subSeen, wildcards[index]);
        }
      };
      if (call->target == I32_EXPR) {
        if (checkMatch(i32)) return true;
      } else if (call->target == I64_EXPR) {
        if (checkMatch(i64)) return true;
      } else if (call->target == F32_EXPR) {
        if (checkMatch(f32)) return true;
      } else if (call->target == F64_EXPR) {
        if (checkMatch(f64)) return true;
      } else if (call->target == ANY_EXPR) {
        if (checkMatch(none)) return true;
      }
      return false;
    };
    return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, compare);
  }

  // Applying/copying

  // Apply the match, generate an output expression from the matched input, performing substitutions as necessary
  Expression* apply() {
    // When copying a wildcard, perform the substitution.
    // TODO: we can reuse nodes, not copying a wildcard when it appears just once,
    //       and we can reuse other individual nodes when they are discarded anyhow.
    auto copy = [this](Expression* curr) -> Expression* {
      CallImport* call = curr->dynCast<CallImport>();
      if (!call || call->operands.size() != 1 ||
          call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) {
        return nullptr;
      }
      Index index = call->operands[0]->cast<Const>()->value.geti32();
      // handle our special functions
      if (call->target == I32_EXPR || call->target == I64_EXPR ||
          call->target == F32_EXPR || call->target == F64_EXPR ||
          call->target == ANY_EXPR) {
        return ExpressionManipulator::copy(wildcards.at(index), wasm);
      }
      return nullptr;
    };
    return ExpressionManipulator::flexibleCopy(pattern.output, wasm, copy);
  }
};

// Utilities

// returns the maximum amount of bits used in an integer expression
// not extremely precise (doesn't look into add operands, etc.)
// LocalInfoProvider is an optional class that can provide answers about
// get_local.
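// For example, (i32.and x (i32.const 15)) uses at most 4 bits, and if x uses at most
// 20 bits then (i32.shr_u x (i32.const 5)) uses at most 15.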
template<typename LocalInfoProvider>
Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) {
  if (auto* const_ = curr->dynCast<Const>()) {
    switch (curr->type) {
      case i32: return 32 - const_->value.countLeadingZeroes().geti32();
      case i64: return 64 - const_->value.countLeadingZeroes().geti64();
      default: WASM_UNREACHABLE();
    }
  } else if (auto* binary = curr->dynCast<Binary>()) {
    switch (binary->op) {
      // 32-bit
      case AddInt32: case SubInt32: case MulInt32:
      case DivSInt32: case DivUInt32: case RemSInt32:
      case RemUInt32: case RotLInt32: case RotRInt32: return 32;
      case AndInt32: return std::min(getMaxBits(binary->left, localInfoProvider),
                                     getMaxBits(binary->right, localInfoProvider));
      case OrInt32: case XorInt32: return std::max(getMaxBits(binary->left, localInfoProvider),
                                                   getMaxBits(binary->right, localInfoProvider));
      case ShlInt32: {
        if (auto* shifts = binary->right->dynCast<Const>()) {
          return std::min(Index(32), getMaxBits(binary->left, localInfoProvider) + Bits::getEffectiveShifts(shifts));
        }
        return 32;
      }
      case ShrUInt32: {
        if (auto* shift = binary->right->dynCast<Const>()) {
          auto maxBits = getMaxBits(binary->left, localInfoProvider);
          auto shifts = std::min(Index(Bits::getEffectiveShifts(shift)), maxBits); // can ignore more shifts than zero us out
          return std::max(Index(0), maxBits - shifts);
        }
        return 32;
      }
      case ShrSInt32: {
        if (auto* shift = binary->right->dynCast<Const>()) {
          auto maxBits = getMaxBits(binary->left, localInfoProvider);
          if (maxBits == 32) return 32;
          auto shifts = std::min(Index(Bits::getEffectiveShifts(shift)), maxBits); // can ignore more shifts than zero us out
          return std::max(Index(0), maxBits - shifts);
        }
        return 32;
      }
      // 64-bit TODO
      // comparisons
      case EqInt32: case NeInt32:
      case LtSInt32: case LtUInt32: case LeSInt32: case LeUInt32:
      case GtSInt32: case GtUInt32: case GeSInt32: case GeUInt32:
      case EqInt64: case NeInt64:
      case LtSInt64: case LtUInt64: case LeSInt64: case LeUInt64:
      case GtSInt64: case GtUInt64: case GeSInt64: case GeUInt64:
      case EqFloat32: case NeFloat32:
      case LtFloat32: case LeFloat32: case GtFloat32: case GeFloat32:
      case EqFloat64: case NeFloat64:
      case LtFloat64: case LeFloat64: case GtFloat64: case GeFloat64: return 1;
      default: {}
    }
  } else if (auto* unary = curr->dynCast<Unary>()) {
    switch (unary->op) {
      case ClzInt32: case CtzInt32: case PopcntInt32: return 6;
      case ClzInt64: case CtzInt64: case PopcntInt64: return 7;
      case EqZInt32: case EqZInt64: return 1;
      case WrapInt64: return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
      default: {}
    }
  } else if (auto* set = curr->dynCast<SetLocal>()) {
    // a tee passes through the value
    return getMaxBits(set->value, localInfoProvider);
  } else if (auto* get = curr->dynCast<GetLocal>()) {
    return localInfoProvider->getMaxBitsForLocal(get);
  } else if (auto* load = curr->dynCast<Load>()) {
    // if signed, then the sign-extension might fill all the bits
    if (!load->signed_) {
      return 8 * load->bytes;
    }
  }
  switch (curr->type) {
    case i32: return 32;
    case i64: return 64;
    case unreachable: return 64; // not interesting, but don't crash
    default: WASM_UNREACHABLE();
  }
}

// Looks through fallthrough operations, like a tee_local's value
// and an unnamed block's final element, to the value that actually flows out.
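// For example, the fallthrough of (tee_local $x (i32.const 7)) is the (i32.const 7),
// as is the fallthrough of an unnamed block whose final element is that tee.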
Expression* getFallthrough(Expression* curr) {
  if (auto* set = curr->dynCast<SetLocal>()) {
    if (set->isTee()) {
      return getFallthrough(set->value);
    }
  } else if (auto* block = curr->dynCast<Block>()) {
    // if no name, we can't be broken to, and then can look at the fallthrough
    if (!block->name.is() && block->list.size() > 0) {
      return getFallthrough(block->list.back());
    }
  }
  return curr;
}

// Useful information about locals
struct LocalInfo {
  static const Index kUnknown = Index(-1);

  Index maxBits;
  Index signExtedBits;
};

struct LocalScanner : PostWalker<LocalScanner> {
  std::vector<LocalInfo>& localInfo;

  LocalScanner(std::vector<LocalInfo>& localInfo) : localInfo(localInfo) {}

  void doWalkFunction(Function* func) {
    // prepare
    localInfo.resize(func->getNumLocals());
    for (Index i = 0; i < func->getNumLocals(); i++) {
      auto& info = localInfo[i];
      if (func->isParam(i)) {
        info.maxBits = getBitsForType(func->getLocalType(i)); // worst-case
        info.signExtedBits = LocalInfo::kUnknown; // we will never know anything
      } else {
        info.maxBits = info.signExtedBits = 0; // we are open to learning
      }
    }
    // walk
    PostWalker<LocalScanner>::doWalkFunction(func);
    // finalize
    for (Index i = 0; i < func->getNumLocals(); i++) {
      auto& info = localInfo[i];
      if (info.signExtedBits == LocalInfo::kUnknown) {
        info.signExtedBits = 0;
      }
    }
  }

  void visitSetLocal(SetLocal* curr) {
    auto* func = getFunction();
    if (func->isParam(curr->index)) return;
    auto type = getFunction()->getLocalType(curr->index);
    if (type != i32 && type != i64) return;
    // an integer var, worth processing
    auto* value = getFallthrough(curr->value);
    auto& info = localInfo[curr->index];
    info.maxBits = std::max(info.maxBits, getMaxBits(value, this));
    auto signExtBits = LocalInfo::kUnknown;
    if (Properties::getSignExtValue(value)) {
      signExtBits = Properties::getSignExtBits(value);
    } else if (auto* load = value->dynCast<Load>()) {
      if (load->signed_) {
        signExtBits = load->bytes * 8;
      }
    }
    if (info.signExtedBits == 0) {
      info.signExtedBits = signExtBits; // first info we see
    } else if (info.signExtedBits != signExtBits) {
      info.signExtedBits = LocalInfo::kUnknown; // contradictory information, give up
    }
  }
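
  // For example, a non-param local that is only ever set from signed 1-byte loads ends
  // up with signExtedBits = 8; if it is also set from some unrelated value, we give up
  // and record kUnknown (finalized to 0 above).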

  // define this for the templated getMaxBits method. we know nothing here yet
  // about locals, so return the maxes
  Index getMaxBitsForLocal(GetLocal* get) {
    return getBitsForType(get->type);
  }

  Index getBitsForType(WasmType type) {
    switch (type) {
      case i32: return 32;
      case i64: return 64;
      default: return -1;
    }
  }
};

// Main pass class
struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> {
  bool isFunctionParallel() override { return true; }

  Pass* create() override { return new OptimizeInstructions; }

  void prepareToRun(PassRunner* runner, Module* module) override {
#if 0
    static DatabaseEnsurer ensurer;
#endif
  }

  void doWalkFunction(Function* func) {
    // first, scan locals
    {
      LocalScanner scanner(localInfo);
      scanner.walkFunction(func);
    }
    // main walk
    WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>>::doWalkFunction(func);
  }

  void visitExpression(Expression* curr) {
    // we may be able to apply multiple patterns, as one may open opportunities that look deeper.
    // NB: patterns must not have cycles
    while (1) {
      auto* handOptimized = handOptimize(curr);
      if (handOptimized) {
        curr = handOptimized;
        replaceCurrent(curr);
        continue;
      }
#if 0
      auto iter = database->patternMap.find(curr->_id);
      if (iter == database->patternMap.end()) return;
      auto& patterns = iter->second;
      bool more = false;
      for (auto& pattern : patterns) {
        Match match(*getModule(), pattern);
        if (match.check(curr)) {
          curr = match.apply();
          replaceCurrent(curr);
          more = true;
          break; // exit the pattern loop, return to the main while loop
        }
      }
      if (!more) break;
#else
      break;
#endif
    }
  }

  // Optimizations that don't yet fit in the pattern DSL, but could be eventually maybe
  Expression* handOptimize(Expression* curr) {
    // if this contains dead code, don't bother trying to optimize it, the type
    // might change (an if might not be unreachable if just one arm is, for example).
    // this optimization pass focuses on actually executing code. the only
    // exceptions are control flow changes
    if (curr->type == unreachable &&
        !curr->is<Break>() && !curr->is<Switch>() && !curr->is<If>()) {
      return nullptr;
    }
    if (auto* binary = curr->dynCast<Binary>()) {
      if (Properties::isSymmetric(binary)) {
        // canonicalize a const to the second position
        if (binary->left->is<Const>() && !binary->right->is<Const>()) {
          std::swap(binary->left, binary->right);
        }
      }
      if (auto* ext = Properties::getAlmostSignExt(binary)) {
        Index extraShifts;
        auto bits = Properties::getAlmostSignExtBits(binary, extraShifts);
        if (extraShifts == 0) {
          if (auto* load = getFallthrough(ext)->dynCast<Load>()) {
            // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc.
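            // e.g. (i32.shr_s (i32.shl (i32.load8_u ...) (i32.const 24)) (i32.const 24))
            //      can become a signed load (i32.load8_s) with the shifts removed entirely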
            if ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16)) {
              // if the value falls through, we can't alter the load, as it might be captured in a tee
              if (load->signed_ == true || load == ext) {
                load->signed_ = true;
                return ext;
              }
            }
          }
        }
        // if the sign-extend input cannot have a sign bit, we don't need it
        // we also don't need it if it already has an identical-sized sign extend
        if (getMaxBits(ext, this) + extraShifts < bits || isSignExted(ext, bits)) {
          return removeAlmostSignExt(binary);
        }
      } else if (binary->op == EqInt32 || binary->op == NeInt32) {
        if (auto* c = binary->right->dynCast<Const>()) {
          if (binary->op == EqInt32 && c->value.geti32() == 0) {
            // equal 0 => eqz
            return Builder(*getModule()).makeUnary(EqZInt32, binary->left);
          }
          if (auto* ext = Properties::getSignExtValue(binary->left)) {
            // we are comparing a sign extend to a constant, which means we can use a cheaper zext
            auto bits = Properties::getSignExtBits(binary->left);
            binary->left = makeZeroExt(ext, bits);
            // when we replace the sign-ext of the non-constant with a zero-ext, we are forcing
            // the high bits to be all zero, instead of all zero or all one depending on the
            // sign bit. so we may be changing the high bits from all one to all zero:
            //  * if the constant value's higher bits are mixed, then it can't be equal anyhow
            //  * if they are all zero, we may get a false true if the non-constant's upper bits
            //    were one. this can only happen if the non-constant's sign bit is set, so this
            //    false true is a risk only if the constant's sign bit is set (otherwise, false).
            //    But a constant with a sign bit but with upper bits zero is impossible to be
            //    equal to a sign-extended value anyhow, so the entire thing is false.
            //  * if they were all one, we may get a false false, if the only difference is in
            //    those upper bits. that means we are equal on the other bits, including the sign
            //    bit. so we can just mask off the upper bits in the constant value, in this
            //    case, forcing them to zero like we do in the zero-extend.
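            // For example, with bits == 8 and a constant of -128 (0xffffff80): the upper 24
            // bits are all ones, so after switching to a zext we can compare against 0x80
            // instead and the result is unchanged. With a constant like 0x180 the upper bits
            // are mixed, and no sign-extended 8-bit value can equal it, so we force an
            // impossible constant.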
            int32_t constValue = c->value.geti32();
            auto upperConstValue = constValue & ~Bits::lowBitMask(bits);
            uint32_t count = PopCount(upperConstValue);
            auto constSignBit = constValue & (1 << (bits - 1));
            if ((count > 0 && count < 32 - bits) || (constSignBit && count == 0)) {
              // mixed, or [zero upper const bits with sign bit set]; the compared values can
              // never be identical, so force something definitely impossible even after zext
              assert(bits < 32);
              c->value = Literal(int32_t(0x80000000));
              // TODO: if no side effects, we can just replace it all with 1 or 0
            } else {
              // otherwise, they are all ones, so we can mask them off as mentioned before
              c->value = c->value.and_(Literal(Bits::lowBitMask(bits)));
            }
            return binary;
          }
        } else if (auto* left = Properties::getSignExtValue(binary->left)) {
          if (auto* right = Properties::getSignExtValue(binary->right)) {
            auto bits = Properties::getSignExtBits(binary->left);
            if (Properties::getSignExtBits(binary->right) == bits) {
              // we are comparing two sign-exts with the same bits, so we may as well replace both with cheaper zexts
              binary->left = makeZeroExt(left, bits);
              binary->right = makeZeroExt(right, bits);
              return binary;
            }
          } else if (auto* load = binary->right->dynCast<Load>()) {
            // we are comparing a load to a sign-ext, we may be able to switch to zext
            auto leftBits = Properties::getSignExtBits(binary->left);
            if (load->signed_ && leftBits == load->bytes * 8) {
              load->signed_ = false;
              binary->left = makeZeroExt(left, leftBits);
              return binary;
            }
          }
        } else if (auto* load = binary->left->dynCast<Load>()) {
          if (auto* right = Properties::getSignExtValue(binary->right)) {
            // we are comparing a load to a sign-ext, we may be able to switch to zext
            auto rightBits = Properties::getSignExtBits(binary->right);
            if (load->signed_ && rightBits == load->bytes * 8) {
              load->signed_ = false;
              binary->right = makeZeroExt(right, rightBits);
              return binary;
            }
          }
        }
        // note that both left and right may be consts, but then we let precompute compute the constant result
      } else if (binary->op == AddInt32 || binary->op == SubInt32) {
        return optimizeAddedConstants(binary);
      }
      // a bunch of operations on a constant right side can be simplified
      if (auto* right = binary->right->dynCast<Const>()) {
        if (binary->op == AndInt32) {
          auto mask = right->value.geti32();
          // and with -1 does nothing (common in asm.js output)
          if (mask == -1) {
            return binary->left;
          }
          // small loads do not need to be masked, the load itself masks
          if (auto* load = binary->left->dynCast<Load>()) {
            if ((load->bytes == 1 && mask == 0xff) || (load->bytes == 2 && mask == 0xffff)) {
              load->signed_ = false;
              return binary->left;
            }
          } else if (auto maskedBits = Bits::getMaskedBits(mask)) {
            if (getMaxBits(binary->left, this) <= maskedBits) {
              // a mask of lower bits is not needed if we are already smaller
              return binary->left;
            }
          }
        }
        // some operations have no effect TODO: many more
        if (right->value == Literal(int32_t(0))) {
          if (binary->op == ShlInt32 || binary->op == ShrUInt32 || binary->op == ShrSInt32) {
            return binary->left;
          }
        }
        // the square of some operations can be merged
        if (auto* left = binary->left->dynCast<Binary>()) {
          if (left->op == binary->op) {
            if (auto* leftRight = left->right->dynCast<Const>()) {
              if (left->op == AndInt32) {
                leftRight->value = leftRight->value.and_(right->value);
                return left;
              } else if (left->op == OrInt32) {
                leftRight->value = leftRight->value.or_(right->value);
                return left;
              } else if (left->op == ShlInt32 || left->op == ShrUInt32 || left->op == ShrSInt32 ||
                         left->op == ShlInt64 || left->op == ShrUInt64 || left->op == ShrSInt64) {
                // shifts only use an effective amount from the constant, so adding must be done carefully
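                // e.g. (x << 3) << 4 can become x << 7, but (x >> 20) >> 20 must not become
                // x >> 40, since an i32 shift of 40 is effectively a shift of 8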
                auto total = Bits::getEffectiveShifts(leftRight) + Bits::getEffectiveShifts(right);
                if (total == Bits::getEffectiveShifts(total, right->type)) {
                  // no overflow, we can do this
                  leftRight->value = LiteralUtils::makeLiteralFromInt32(total, right->type);
                  return left;
                } // TODO: handle overflows
              }
            }
          }
        }
      }
      if (binary->op == AndInt32 || binary->op == OrInt32) {
        return conditionalizeExpensiveOnBitwise(binary);
      }
    } else if (auto* unary = curr->dynCast<Unary>()) {
      // de-morgan's laws
      if (unary->op == EqZInt32) {
        if (auto* inner = unary->value->dynCast<Binary>()) {
          switch (inner->op) {
            case EqInt32:   inner->op = NeInt32;   return inner;
            case NeInt32:   inner->op = EqInt32;   return inner;
            case LtSInt32:  inner->op = GeSInt32;  return inner;
            case LtUInt32:  inner->op = GeUInt32;  return inner;
            case LeSInt32:  inner->op = GtSInt32;  return inner;
            case LeUInt32:  inner->op = GtUInt32;  return inner;
            case GtSInt32:  inner->op = LeSInt32;  return inner;
            case GtUInt32:  inner->op = LeUInt32;  return inner;
            case GeSInt32:  inner->op = LtSInt32;  return inner;
            case GeUInt32:  inner->op = LtUInt32;  return inner;
            case EqInt64:   inner->op = NeInt64;   return inner;
            case NeInt64:   inner->op = EqInt64;   return inner;
            case LtSInt64:  inner->op = GeSInt64;  return inner;
            case LtUInt64:  inner->op = GeUInt64;  return inner;
            case LeSInt64:  inner->op = GtSInt64;  return inner;
            case LeUInt64:  inner->op = GtUInt64;  return inner;
            case GtSInt64:  inner->op = LeSInt64;  return inner;
            case GtUInt64:  inner->op = LeUInt64;  return inner;
            case GeSInt64:  inner->op = LtSInt64;  return inner;
            case GeUInt64:  inner->op = LtUInt64;  return inner;
            case EqFloat32: inner->op = NeFloat32; return inner;
            case NeFloat32: inner->op = EqFloat32; return inner;
            case EqFloat64: inner->op = NeFloat64; return inner;
            case NeFloat64: inner->op = EqFloat64; return inner;
            default: {}
          }
        }
        // an eqz of a sign extension can be done on the zero-extension instead
        if (auto* ext = Properties::getSignExtValue(unary->value)) {
          // we are comparing a sign extend to zero, which means we can use a cheaper zext
          auto bits = Properties::getSignExtBits(unary->value);
          unary->value = makeZeroExt(ext, bits);
          return unary;
        }
      }
    } else if (auto* set = curr->dynCast<SetGlobal>()) {
      // optimize out a set of a get of the same global
      auto* get = set->value->dynCast<GetGlobal>();
      if (get && get->name == set->name) {
        ExpressionManipulator::nop(curr);
      }
    } else if (auto* iff = curr->dynCast<If>()) {
      iff->condition = optimizeBoolean(iff->condition);
      if (iff->ifFalse) {
        if (auto* unary = iff->condition->dynCast<Unary>()) {
          if (unary->op == EqZInt32) {
            // flip the if-else arms to get rid of an eqz
            iff->condition = unary->value;
            std::swap(iff->ifTrue, iff->ifFalse);
          }
        }
        if (ExpressionAnalyzer::equal(iff->ifTrue, iff->ifFalse)) {
          // sides are identical, fold
          if (!EffectAnalyzer(getPassOptions(), iff->condition).hasSideEffects()) {
            return iff->ifTrue;
          } else {
            Builder builder(*getModule());
            return builder.makeSequence(
              builder.makeDrop(iff->condition),
              iff->ifTrue
            );
          }
        }
      }
    } else if (auto* select = curr->dynCast