diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/iteration.h | 213 | ||||
-rw-r--r-- | src/ir/properties.h | 193 | ||||
-rw-r--r-- | src/tools/wasm-reduce.cpp | 137 | ||||
-rw-r--r-- | src/wasm.h | 2 |
4 files changed, 409 insertions, 136 deletions
diff --git a/src/ir/iteration.h b/src/ir/iteration.h new file mode 100644 index 000000000..6cf149894 --- /dev/null +++ b/src/ir/iteration.h @@ -0,0 +1,213 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_iteration_h +#define wasm_ir_iteration_h + +#include "wasm.h" + +namespace wasm { + +// +// Allows iteration over the children of the expression, in order of execution +// where relevant. +// +// * This skips missing children, e.g. if an if has no else, it is represented +// as having 2 children (and not 3 with the last a nullptr). +// +// In general, it is preferable not to use this class and to directly access +// the children (using e.g. iff->ifTrue etc.), as that is faster. However, in +// cases where speed does not matter, this can be convenient. +// + +class ChildIterator { + struct Iterator { + const ChildIterator& parent; + Index index; + + Iterator(const ChildIterator& parent, Index index) : parent(parent), index(index) {} + + bool operator!=(const Iterator& other) const { + return index != other.index || &parent != &(other.parent); + } + + void operator++() { + index++; + } + + Expression* operator*() { + return parent.children[index]; + } + }; + +public: + std::vector<Expression*> children; + + ChildIterator(Expression* expr) { + switch (expr->_id) { + case Expression::Id::BlockId: { + auto& list = expr->cast<Block>()->list; + for (auto* child : list) { + children.push_back(child); + } + break; + } + case Expression::Id::IfId: { + auto* iff = expr->cast<If>(); + children.push_back(iff->condition); + children.push_back(iff->ifTrue); + if (iff->ifFalse) children.push_back(iff->ifFalse); + break; + } + case Expression::Id::LoopId: { + children.push_back(expr->cast<Loop>()->body); + break; + } + case Expression::Id::BreakId: { + auto* br = expr->cast<Break>(); + if (br->value) children.push_back(br->value); + if (br->condition) children.push_back(br->condition); + break; + } + case Expression::Id::SwitchId: { + auto* br = expr->cast<Switch>(); + if (br->value) children.push_back(br->value); + children.push_back(br->condition); + break; + } + case Expression::Id::CallId: { + auto& operands = expr->cast<Call>()->operands; + for (auto* child : operands) { + children.push_back(child); + } + break; + } + case Expression::Id::CallImportId: { + auto& operands = expr->cast<CallImport>()->operands; + for (auto* child : operands) { + children.push_back(child); + } + break; + } + case Expression::Id::CallIndirectId: { + auto* call = expr->cast<CallIndirect>(); + auto& operands = call->operands; + for (auto* child : operands) { + children.push_back(child); + } + children.push_back(call->target); + break; + } + case Expression::Id::SetLocalId: { + children.push_back(expr->cast<SetLocal>()->value); + break; + } + case Expression::Id::SetGlobalId: { + children.push_back(expr->cast<SetGlobal>()->value); + break; + } + case Expression::Id::LoadId: { + children.push_back(expr->cast<Load>()->ptr); + break; + } + case Expression::Id::StoreId: { + auto* store = expr->cast<Store>(); + children.push_back(store->ptr); + children.push_back(store->value); + break; + } + case Expression::Id::UnaryId: { + children.push_back(expr->cast<Unary>()->value); + break; + } + case Expression::Id::BinaryId: { + auto* binary = expr->cast<Binary>(); + children.push_back(binary->left); + children.push_back(binary->right); + break; + } + case Expression::Id::SelectId: { + auto* select = expr->cast<Select>(); + children.push_back(select->ifTrue); + children.push_back(select->ifFalse); + children.push_back(select->condition); + break; + } + case Expression::Id::DropId: { + children.push_back(expr->cast<Drop>()->value); + break; + } + case Expression::Id::ReturnId: { + auto* ret = expr->dynCast<Return>(); + if (ret->value) children.push_back(ret->value); + break; + } + case Expression::Id::HostId: { + auto& operands = expr->cast<Host>()->operands; + for (auto* child : operands) { + children.push_back(child); + } + break; + } + case Expression::Id::AtomicRMWId: { + auto* atomic = expr->cast<AtomicRMW>(); + children.push_back(atomic->ptr); + children.push_back(atomic->value); + break; + } + case Expression::Id::AtomicCmpxchgId: { + auto* atomic = expr->cast<AtomicCmpxchg>(); + children.push_back(atomic->ptr); + children.push_back(atomic->expected); + children.push_back(atomic->replacement); + break; + } + case Expression::Id::AtomicWaitId: { + auto* atomic = expr->cast<AtomicWait>(); + children.push_back(atomic->ptr); + children.push_back(atomic->expected); + children.push_back(atomic->timeout); + break; + } + case Expression::Id::AtomicWakeId: { + auto* atomic = expr->cast<AtomicWake>(); + children.push_back(atomic->ptr); + children.push_back(atomic->wakeCount); + break; + } + case Expression::Id::GetLocalId: + case Expression::Id::GetGlobalId: + case Expression::Id::ConstId: + case Expression::Id::NopId: + case Expression::Id::UnreachableId: { + break; // no children + } + default: WASM_UNREACHABLE(); + } + } + + Iterator begin() const { + return Iterator(*this, 0); + } + Iterator end() const { + return Iterator(*this, children.size()); + } +}; + +} // wasm + +#endif // wasm_ir_iteration_h + diff --git a/src/ir/properties.h b/src/ir/properties.h index cf481218c..7cbbe73f0 100644 --- a/src/ir/properties.h +++ b/src/ir/properties.h @@ -22,51 +22,62 @@ namespace wasm { -struct Properties { - static bool emitsBoolean(Expression* curr) { - if (auto* unary = curr->dynCast<Unary>()) { - return unary->isRelational(); - } else if (auto* binary = curr->dynCast<Binary>()) { - return binary->isRelational(); - } - return false; - } +namespace Properties { - static bool isSymmetric(Binary* binary) { - switch (binary->op) { - case AddInt32: - case MulInt32: - case AndInt32: - case OrInt32: - case XorInt32: - case EqInt32: - case NeInt32: - - case AddInt64: - case MulInt64: - case AndInt64: - case OrInt64: - case XorInt64: - case EqInt64: - case NeInt64: return true; - - default: return false; - } +inline bool emitsBoolean(Expression* curr) { + if (auto* unary = curr->dynCast<Unary>()) { + return unary->isRelational(); + } else if (auto* binary = curr->dynCast<Binary>()) { + return binary->isRelational(); } - - // Check if an expression is a sign-extend, and if so, returns the value - // that is extended, otherwise nullptr - static Expression* getSignExtValue(Expression* curr) { - if (auto* outer = curr->dynCast<Binary>()) { - if (outer->op == ShrSInt32) { - if (auto* outerConst = outer->right->dynCast<Const>()) { - if (outerConst->value.geti32() != 0) { - if (auto* inner = outer->left->dynCast<Binary>()) { - if (inner->op == ShlInt32) { - if (auto* innerConst = inner->right->dynCast<Const>()) { - if (outerConst->value == innerConst->value) { - return inner->left; - } + return false; +} + +inline bool isSymmetric(Binary* binary) { + switch (binary->op) { + case AddInt32: + case MulInt32: + case AndInt32: + case OrInt32: + case XorInt32: + case EqInt32: + case NeInt32: + + case AddInt64: + case MulInt64: + case AndInt64: + case OrInt64: + case XorInt64: + case EqInt64: + case NeInt64: return true; + + default: return false; + } +} + +// Check if an expression is a control flow construct with a name, +// which implies it may have breaks to it. +inline bool isNamedControlFlow(Expression* curr) { + if (auto* block = curr->dynCast<Block>()) { + return block->name.is(); + } else if (auto* loop = curr->dynCast<Loop>()) { + return loop->name.is(); + } + return false; +} + +// Check if an expression is a sign-extend, and if so, returns the value +// that is extended, otherwise nullptr +inline Expression* getSignExtValue(Expression* curr) { + if (auto* outer = curr->dynCast<Binary>()) { + if (outer->op == ShrSInt32) { + if (auto* outerConst = outer->right->dynCast<Const>()) { + if (outerConst->value.geti32() != 0) { + if (auto* inner = outer->left->dynCast<Binary>()) { + if (inner->op == ShlInt32) { + if (auto* innerConst = inner->right->dynCast<Const>()) { + if (outerConst->value == innerConst->value) { + return inner->left; } } } @@ -74,28 +85,28 @@ struct Properties { } } } - return nullptr; - } - - // gets the size of the sign-extended value - static Index getSignExtBits(Expression* curr) { - return 32 - Bits::getEffectiveShifts(curr->cast<Binary>()->right); } - - // Check if an expression is almost a sign-extend: perhaps the inner shift - // is too large. We can split the shifts in that case, which is sometimes - // useful (e.g. if we can remove the signext) - static Expression* getAlmostSignExt(Expression* curr) { - if (auto* outer = curr->dynCast<Binary>()) { - if (outer->op == ShrSInt32) { - if (auto* outerConst = outer->right->dynCast<Const>()) { - if (outerConst->value.geti32() != 0) { - if (auto* inner = outer->left->dynCast<Binary>()) { - if (inner->op == ShlInt32) { - if (auto* innerConst = inner->right->dynCast<Const>()) { - if (Bits::getEffectiveShifts(outerConst) <= Bits::getEffectiveShifts(innerConst)) { - return inner->left; - } + return nullptr; +} + +// gets the size of the sign-extended value +inline Index getSignExtBits(Expression* curr) { + return 32 - Bits::getEffectiveShifts(curr->cast<Binary>()->right); +} + +// Check if an expression is almost a sign-extend: perhaps the inner shift +// is too large. We can split the shifts in that case, which is sometimes +// useful (e.g. if we can remove the signext) +inline Expression* getAlmostSignExt(Expression* curr) { + if (auto* outer = curr->dynCast<Binary>()) { + if (outer->op == ShrSInt32) { + if (auto* outerConst = outer->right->dynCast<Const>()) { + if (outerConst->value.geti32() != 0) { + if (auto* inner = outer->left->dynCast<Binary>()) { + if (inner->op == ShlInt32) { + if (auto* innerConst = inner->right->dynCast<Const>()) { + if (Bits::getEffectiveShifts(outerConst) <= Bits::getEffectiveShifts(innerConst)) { + return inner->left; } } } @@ -103,39 +114,41 @@ struct Properties { } } } - return nullptr; - } - - // gets the size of the almost sign-extended value, as well as the - // extra shifts, if any - static Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) { - extraShifts = Bits::getEffectiveShifts(curr->cast<Binary>()->left->cast<Binary>()->right) - - Bits::getEffectiveShifts(curr->cast<Binary>()->right); - return getSignExtBits(curr); } - - // Check if an expression is a zero-extend, and if so, returns the value - // that is extended, otherwise nullptr - static Expression* getZeroExtValue(Expression* curr) { - if (auto* binary = curr->dynCast<Binary>()) { - if (binary->op == AndInt32) { - if (auto* c = binary->right->dynCast<Const>()) { - if (Bits::getMaskedBits(c->value.geti32())) { - return binary->right; - } + return nullptr; +} + +// gets the size of the almost sign-extended value, as well as the +// extra shifts, if any +inline Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) { + extraShifts = Bits::getEffectiveShifts(curr->cast<Binary>()->left->cast<Binary>()->right) - + Bits::getEffectiveShifts(curr->cast<Binary>()->right); + return getSignExtBits(curr); +} + +// Check if an expression is a zero-extend, and if so, returns the value +// that is extended, otherwise nullptr +inline Expression* getZeroExtValue(Expression* curr) { + if (auto* binary = curr->dynCast<Binary>()) { + if (binary->op == AndInt32) { + if (auto* c = binary->right->dynCast<Const>()) { + if (Bits::getMaskedBits(c->value.geti32())) { + return binary->right; } } } - return nullptr; } + return nullptr; +} - // gets the size of the sign-extended value - static Index getZeroExtBits(Expression* curr) { - return Bits::getMaskedBits(curr->cast<Binary>()->right->cast<Const>()->value.geti32()); - } -}; +// gets the size of the sign-extended value +inline Index getZeroExtBits(Expression* curr) { + return Bits::getMaskedBits(curr->cast<Binary>()->right->cast<Const>()->value.geti32()); +} + +} // Properties } // wasm -#endif // wams_ir_properties_h +#endif // wasm_ir_properties_h diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp index 7b7ff9da3..0dfb97907 100644 --- a/src/tools/wasm-reduce.cpp +++ b/src/tools/wasm-reduce.cpp @@ -35,7 +35,9 @@ #include "wasm-io.h" #include "wasm-builder.h" #include "ir/branch-utils.h" +#include "ir/iteration.h" #include "ir/literal-utils.h" +#include "ir/properties.h" #include "wasm-validator.h" #ifdef _WIN32 #ifndef NOMINMAX @@ -403,8 +405,9 @@ struct Reducer : public WalkerPass<PostWalker<Reducer, UnifiedExpressionVisitor< // don't need to duplicate work that they do void visitExpression(Expression* curr) { + // type-based reductions if (curr->type == none) { - if (tryToReduceCurrentToNone()) return; + if (tryToReduceCurrentToNop()) return; } else if (isConcreteType(curr->type)) { if (tryToReduceCurrentToConst()) return; } else { @@ -419,8 +422,6 @@ struct Reducer : public WalkerPass<PostWalker<Reducer, UnifiedExpressionVisitor< return; } } - if (tryToReplaceCurrent(iff->ifTrue)) return; - if (iff->ifFalse && tryToReplaceCurrent(iff->ifFalse)) return; handleCondition(iff->condition); } else if (auto* br = curr->dynCast<Break>()) { handleCondition(br->condition); @@ -428,26 +429,6 @@ struct Reducer : public WalkerPass<PostWalker<Reducer, UnifiedExpressionVisitor< handleCondition(select->condition); } else if (auto* sw = curr->dynCast<Switch>()) { handleCondition(sw->condition); - } else if (auto* set = curr->dynCast<SetLocal>()) { - if (set->isTee()) { - // maybe we don't need the set - tryToReplaceCurrent(set->value); - } - } else if (auto* unary = curr->dynCast<Unary>()) { - // maybe we can pass through - tryToReplaceCurrent(unary->value); - } else if (auto* binary = curr->dynCast<Binary>()) { - // maybe we can pass through - if (!tryToReplaceCurrent(binary->left)) { - tryToReplaceCurrent(binary->right); - } - } else if (auto* call = curr->dynCast<Call>()) { - handleCall(call); - } else if (auto* call = curr->dynCast<CallImport>()) { - handleCall(call); - } else if (auto* call = curr->dynCast<CallIndirect>()) { - if (tryToReplaceCurrent(call->target)) return; - handleCall(call); } else if (auto* block = curr->dynCast<Block>()) { if (!shouldTryToReduce()) return; // replace a singleton @@ -479,24 +460,66 @@ struct Reducer : public WalkerPass<PostWalker<Reducer, UnifiedExpressionVisitor< } i++; } + return; // nothing more to do } else if (auto* loop = curr->dynCast<Loop>()) { if (shouldTryToReduce() && !BranchUtils::BranchSeeker::hasNamed(loop, loop->name)) { tryToReplaceCurrent(loop->body); } - } else if (auto* rmw = curr->dynCast<AtomicRMW>()) { - if (tryToReplaceCurrent(rmw->ptr)) return; - if (tryToReplaceCurrent(rmw->value)) return; - } else if (auto* cmpx = curr->dynCast<AtomicCmpxchg>()) { - if (tryToReplaceCurrent(cmpx->ptr)) return; - if (tryToReplaceCurrent(cmpx->expected)) return; - if (tryToReplaceCurrent(cmpx->replacement)) return; - } else if (auto* wait = curr->dynCast<AtomicWait>()) { - if (tryToReplaceCurrent(wait->ptr)) return; - if (tryToReplaceCurrent(wait->expected)) return; - if (tryToReplaceCurrent(wait->timeout)) return; - } else if (auto* wake = curr->dynCast<AtomicWake>()) { - if (tryToReplaceCurrent(wake->ptr)) return; - if (tryToReplaceCurrent(wake->wakeCount)) return; + return; // nothing more to do + } + // Finally, try to replace with a child. + for (auto* child : ChildIterator(curr)) { + if (tryToReplaceCurrent(child)) return; + } + // If that didn't work, try to replace with a child + a unary conversion + if (isConcreteType(curr->type) && + !curr->is<Unary>()) { // but not if it's already unary + for (auto* child : ChildIterator(curr)) { + if (child->type == curr->type) continue; // already tried + if (!isConcreteType(child->type)) continue; // no conversion + Expression* fixed; + switch (curr->type) { + case i32: { + switch (child->type) { + case i64: fixed = builder->makeUnary(WrapInt64, child); break; + case f32: fixed = builder->makeUnary(TruncSFloat32ToInt32, child); break; + case f64: fixed = builder->makeUnary(TruncSFloat64ToInt32, child); break; + default: WASM_UNREACHABLE(); + } + break; + } + case i64: { + switch (child->type) { + case i32: fixed = builder->makeUnary(ExtendSInt32, child); break; + case f32: fixed = builder->makeUnary(TruncSFloat32ToInt64, child); break; + case f64: fixed = builder->makeUnary(TruncSFloat64ToInt64, child); break; + default: WASM_UNREACHABLE(); + } + break; + } + case f32: { + switch (child->type) { + case i32: fixed = builder->makeUnary(ConvertSInt32ToFloat32, child); break; + case i64: fixed = builder->makeUnary(ConvertSInt64ToFloat32, child); break; + case f64: fixed = builder->makeUnary(DemoteFloat64, child); break; + default: WASM_UNREACHABLE(); + } + break; + } + case f64: { + switch (child->type) { + case i32: fixed = builder->makeUnary(ConvertSInt32ToFloat64, child); break; + case i64: fixed = builder->makeUnary(ConvertSInt64ToFloat64, child); break; + case f32: fixed = builder->makeUnary(PromoteFloat32, child); break; + default: WASM_UNREACHABLE(); + } + break; + } + default: WASM_UNREACHABLE(); + } + assert(fixed->type == curr->type); + if (tryToReplaceCurrent(fixed)) return; + } } } @@ -649,6 +672,37 @@ struct Reducer : public WalkerPass<PostWalker<Reducer, UnifiedExpressionVisitor< skip = std::min(size_t(factor), 2 * skip); } } + // If we are left with a single function that is not exported or used in + // a table, that is useful as then we can change the return type. + if (module->functions.size() == 1 && module->exports.empty() && module->table.segments.empty()) { + auto* func = module->functions[0].get(); + // We can't remove something that might have breaks to it. + if (!Properties::isNamedControlFlow(func->body)) { + auto funcType = func->type; + auto funcResult = func->result; + auto* funcBody = func->body; + for (auto* child : ChildIterator(func->body)) { + if (!(isConcreteType(child->type) || child->type == none)) { + continue; // not something a function can return + } + // Try to replace the body with the child, fixing up the function + // to accept it. + func->type = Name(); + func->result = child->type; + func->body = child; + if (writeAndTestReduction()) { + // great, we succeeded! + std::cerr << "| altered function result type\n"; + noteReduction(1); + break; + } + // Undo. + func->type = funcType; + func->result = funcResult; + func->body = funcBody; + } + } + } } bool tryToRemoveFunctions(std::vector<Name> names) { @@ -729,14 +783,7 @@ struct Reducer : public WalkerPass<PostWalker<Reducer, UnifiedExpressionVisitor< } } - template<typename T> - void handleCall(T* call) { - for (auto* op : call->operands) { - if (tryToReplaceCurrent(op)) return; - } - } - - bool tryToReduceCurrentToNone() { + bool tryToReduceCurrentToNop() { auto* curr = getCurrent(); if (curr->is<Nop>()) return false; // try to replace with a trivial value diff --git a/src/wasm.h b/src/wasm.h index 6277d32fc..45881e519 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -196,8 +196,8 @@ public: HostId, NopId, UnreachableId, - AtomicCmpxchgId, AtomicRMWId, + AtomicCmpxchgId, AtomicWaitId, AtomicWakeId, NumExpressionIds |