diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ast/bits.h | 20 | ||||
-rw-r--r-- | src/ast/literal-utils.h | 26 | ||||
-rw-r--r-- | src/passes/MergeBlocks.cpp | 34 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 46 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 13 |
5 files changed, 103 insertions, 36 deletions
diff --git a/src/ast/bits.h b/src/ast/bits.h index d88cb5edb..11cf7b06d 100644 --- a/src/ast/bits.h +++ b/src/ast/bits.h @@ -39,6 +39,26 @@ struct Bits { // this is indeed a mask return 32 - CountLeadingZeroes(mask); } + + // gets the number of effective shifts a shift operation does. In + // wasm, only 5 bits matter for 32-bit shifts, and 6 for 64. + static Index getEffectiveShifts(Index amount, WasmType type) { + if (type == i32) { + return amount & 31; + } else if (type == i64) { + return amount & 63; + } + WASM_UNREACHABLE(); + } + + static Index getEffectiveShifts(Const* amount) { + if (amount->type == i32) { + return getEffectiveShifts(amount->value.geti32(), i32); + } else if (amount->type == i64) { + return getEffectiveShifts(amount->value.geti64(), i64); + } + WASM_UNREACHABLE(); + } }; } // namespace wasm diff --git a/src/ast/literal-utils.h b/src/ast/literal-utils.h index 7e75e8bc8..afa8146b9 100644 --- a/src/ast/literal-utils.h +++ b/src/ast/literal-utils.h @@ -23,21 +23,31 @@ namespace wasm { namespace LiteralUtils { -inline Expression* makeZero(WasmType type, Module& wasm) { - Literal value; +inline Literal makeLiteralFromInt32(int32_t x, WasmType type) { switch (type) { - case i32: value = Literal(int32_t(0)); break; - case i64: value = Literal(int64_t(0)); break; - case f32: value = Literal(float(0)); break; - case f64: value = Literal(double(0)); break; + case i32: return Literal(int32_t(x)); break; + case i64: return Literal(int64_t(x)); break; + case f32: return Literal(float(x)); break; + case f64: return Literal(double(x)); break; default: WASM_UNREACHABLE(); } +} + +inline Literal makeLiteralZero(WasmType type) { + return makeLiteralFromInt32(0, type); +} + +inline Expression* makeFromInt32(int32_t x, WasmType type, Module& wasm) { auto* ret = wasm.allocator.alloc<Const>(); - ret->value = value; - ret->type = value.type; + ret->value = makeLiteralFromInt32(x, type); + ret->type = type; return ret; } +inline Expression* makeZero(WasmType type, Module& wasm) { + return makeFromInt32(0, type, wasm); +} + } // namespace LiteralUtils } // namespace wasm diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index d32948bee..7c086a920 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -73,14 +73,23 @@ namespace wasm { // For example, if there is a switch targeting us, we can't do it - we can't remove the value from other targets struct ProblemFinder : public ControlFlowWalker<ProblemFinder> { Name origin; - bool foundSwitch = false; + bool foundProblem = false; // count br_ifs, and dropped br_ifs. if they don't match, then a br_if flow value is used, and we can't drop it Index brIfs = 0; Index droppedBrIfs = 0; + PassOptions& passOptions; + + ProblemFinder(PassOptions& passOptions) : passOptions(passOptions) {} void visitBreak(Break* curr) { - if (curr->name == origin && curr->condition) { - brIfs++; + if (curr->name == origin) { + if (curr->condition) { + brIfs++; + } + // if the value has side effects, we can't remove it + if (EffectAnalyzer(passOptions, curr->value).hasSideEffects()) { + foundProblem = true; + } } } @@ -94,12 +103,12 @@ struct ProblemFinder : public ControlFlowWalker<ProblemFinder> { void visitSwitch(Switch* curr) { if (curr->default_ == origin) { - foundSwitch = true; + foundProblem = true; return; } for (auto& target : curr->targets) { if (target == origin) { - foundSwitch = true; + foundProblem = true; return; } } @@ -107,7 +116,7 @@ struct ProblemFinder : public ControlFlowWalker<ProblemFinder> { bool found() { assert(brIfs >= droppedBrIfs); - return foundSwitch || brIfs > droppedBrIfs; + return foundProblem || brIfs > droppedBrIfs; } }; @@ -115,6 +124,9 @@ struct ProblemFinder : public ControlFlowWalker<ProblemFinder> { // While doing so it can create new blocks, so optimize blocks as well. struct BreakValueDropper : public ControlFlowWalker<BreakValueDropper> { Name origin; + PassOptions& passOptions; + + BreakValueDropper(PassOptions& passOptions) : passOptions(passOptions) {} void visitBlock(Block* curr); @@ -143,7 +155,7 @@ struct BreakValueDropper : public ControlFlowWalker<BreakValueDropper> { }; // core block optimizer routine -static void optimizeBlock(Block* curr, Module* module) { +static void optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) { bool more = true; bool changed = false; while (more) { @@ -159,14 +171,14 @@ static void optimizeBlock(Block* curr, Module* module) { if (child->name.is()) { Expression* expression = child; // check if it's ok to remove the value from all breaks to us - ProblemFinder finder; + ProblemFinder finder(passOptions); finder.origin = child->name; finder.walk(expression); if (finder.found()) { child = nullptr; } else { // fix up breaks - BreakValueDropper fixer; + BreakValueDropper fixer(passOptions); fixer.origin = child->name; fixer.setModule(module); fixer.walk(expression); @@ -217,7 +229,7 @@ static void optimizeBlock(Block* curr, Module* module) { } void BreakValueDropper::visitBlock(Block* curr) { - optimizeBlock(curr, getModule()); + optimizeBlock(curr, getModule(), passOptions); } struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { @@ -226,7 +238,7 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { Pass* create() override { return new MergeBlocks; } void visitBlock(Block *curr) { - optimizeBlock(curr, getModule()); + optimizeBlock(curr, getModule(), getPassOptions()); } Block* optimize(Expression* curr, Expression*& child, Block* outer = nullptr, Expression** dependency1 = nullptr, Expression** dependency2 = nullptr) { diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index de6f96ccc..4fd7cbe8d 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -29,6 +29,7 @@ #include <ast/effects.h> #include <ast/manipulation.h> #include <ast/properties.h> +#include <ast/literal-utils.h> namespace wasm { @@ -188,14 +189,14 @@ Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { case OrInt32: case XorInt32: return std::max(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); case ShlInt32: { if (auto* shifts = binary->right->dynCast<Const>()) { - return std::min(Index(32), getMaxBits(binary->left, localInfoProvider) + shifts->value.geti32()); + return std::min(Index(32), getMaxBits(binary->left, localInfoProvider) + Bits::getEffectiveShifts(shifts)); } return 32; } case ShrUInt32: { if (auto* shift = binary->right->dynCast<Const>()) { auto maxBits = getMaxBits(binary->left, localInfoProvider); - auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out + auto shifts = std::min(Index(Bits::getEffectiveShifts(shift)), maxBits); // can ignore more shifts than zero us out return std::max(Index(0), maxBits - shifts); } return 32; @@ -204,7 +205,7 @@ Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { if (auto* shift = binary->right->dynCast<Const>()) { auto maxBits = getMaxBits(binary->left, localInfoProvider); if (maxBits == 32) return 32; - auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out + auto shifts = std::min(Index(Bits::getEffectiveShifts(shift)), maxBits); // can ignore more shifts than zero us out return std::max(Index(0), maxBits - shifts); } return 32; @@ -533,9 +534,16 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } else if (left->op == OrInt32) { leftRight->value = leftRight->value.or_(right->value); return left; - } else if (left->op == ShlInt32 || left->op == ShrUInt32 || left->op == ShrSInt32) { - leftRight->value = leftRight->value.add(right->value); - return left; + } else if (left->op == ShlInt32 || left->op == ShrUInt32 || left->op == ShrSInt32 || + left->op == ShlInt64 || left->op == ShrUInt64 || left->op == ShrSInt64) { + // shifts only use an effective amount from the constant, so adding must + // be done carefully + auto total = Bits::getEffectiveShifts(leftRight) + Bits::getEffectiveShifts(right); + if (total == Bits::getEffectiveShifts(total, right->type)) { + // no overflow, we can do this + leftRight->value = LiteralUtils::makeLiteralFromInt32(total, right->type); + return left; + } // TODO: handle overflows } } } @@ -874,14 +882,17 @@ private: auto* left = binary->left; auto* right = binary->right; if (!Properties::emitsBoolean(left) || !Properties::emitsBoolean(right)) return nullptr; - auto leftEffects = EffectAnalyzer(getPassOptions(), left).hasSideEffects(); - auto rightEffects = EffectAnalyzer(getPassOptions(), right).hasSideEffects(); - if (leftEffects && rightEffects) return nullptr; // both must execute - // canonicalize with side effects, if any, happening on the left - if (rightEffects) { + auto leftEffects = EffectAnalyzer(getPassOptions(), left); + auto rightEffects = EffectAnalyzer(getPassOptions(), right); + auto leftHasSideEffects = leftEffects.hasSideEffects(); + auto rightHasSideEffects = rightEffects.hasSideEffects(); + if (leftHasSideEffects && rightHasSideEffects) return nullptr; // both must execute + // canonicalize with side effects, if any, happening on the left + if (rightHasSideEffects) { if (CostAnalyzer(left).cost < MIN_COST) return nullptr; // avoidable code is too cheap + if (leftEffects.invalidates(rightEffects)) return nullptr; // cannot reorder std::swap(left, right); - } else if (leftEffects) { + } else if (leftHasSideEffects) { if (CostAnalyzer(right).cost < MIN_COST) return nullptr; // avoidable code is too cheap } else { // no side effects, reorder based on cost estimation @@ -908,8 +919,15 @@ private: // it's better to do the opposite for gzip purposes as well as for readability. auto* last = ptr->dynCast<Const>(); if (last) { - last->value = Literal(int32_t(last->value.geti32() + offset)); - offset = 0; + // don't do this if it would wrap the pointer + uint64_t value64 = last->value.geti32(); + uint64_t offset64 = offset; + if (value64 <= std::numeric_limits<int32_t>::max() && + offset64 <= std::numeric_limits<int32_t>::max() && + value64 + offset64 <= std::numeric_limits<int32_t>::max()) { + last->value = Literal(int32_t(value64 + offset64)); + offset = 0; + } } } diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 346018c7d..88b967fdf 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -545,7 +545,7 @@ void WasmBinaryWriter::recurse(Expression*& curr) { } static bool brokenTo(Block* block) { - return block->name.is() && BranchUtils::BranchSeeker::has(block, block->name); + return block->name.is() && BranchUtils::BranchSeeker::hasNamed(block, block->name); } void WasmBinaryWriter::visitBlock(Block *curr) { @@ -636,7 +636,7 @@ int32_t WasmBinaryWriter::getBreakIndex(Name name) { // -1 if not found return breakStack.size() - 1 - i; } } - std::cerr << "bad break: " << name << std::endl; + std::cerr << "bad break: " << name << " in " << currFunction->name << std::endl; abort(); } @@ -2102,7 +2102,14 @@ void WasmBinaryBuilder::visitBlock(Block *curr) { } for (size_t i = start; i < end; i++) { if (debug) std::cerr << " " << size_t(expressionStack[i]) << "\n zz Block element " << curr->list.size() << std::endl; - curr->list.push_back(expressionStack[i]); + auto* item = expressionStack[i]; + curr->list.push_back(item); + if (i < end - 1) { + // stacky&unreachable code may introduce elements that need to be dropped in non-final positions + if (isConcreteWasmType(item->type)) { + curr->list.back() = Builder(wasm).makeDrop(curr->list.back()); + } + } } expressionStack.resize(start); curr->finalize(curr->type); |