diff options
-rw-r--r-- | src/ir/abstract.h | 8 | ||||
-rw-r--r-- | src/ir/match.h | 800 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 720 | ||||
-rw-r--r-- | test/example/match.cpp | 448 | ||||
-rw-r--r-- | test/example/match.txt | 9 |
5 files changed, 1627 insertions, 358 deletions
diff --git a/src/ir/abstract.h b/src/ir/abstract.h index 5f1fd393e..ce5b6b008 100644 --- a/src/ir/abstract.h +++ b/src/ir/abstract.h @@ -98,9 +98,7 @@ inline UnaryOp getUnary(Type type, Op op) { } break; } - case Type::v128: { - WASM_UNREACHABLE("v128 not implemented yet"); - } + case Type::v128: case Type::funcref: case Type::externref: case Type::exnref: @@ -263,9 +261,7 @@ inline BinaryOp getBinary(Type type, Op op) { } break; } - case Type::v128: { - WASM_UNREACHABLE("v128 not implemented yet"); - } + case Type::v128: case Type::funcref: case Type::externref: case Type::exnref: diff --git a/src/ir/match.h b/src/ir/match.h new file mode 100644 index 000000000..9b7d1ff0f --- /dev/null +++ b/src/ir/match.h @@ -0,0 +1,800 @@ +/* + * Copyright 2020 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// match.h - Provides an easily extensible layered API for matching expression +// patterns and extracting their components. The low-level API provides modular +// building blocks for creating matchers for any data type and the high-level +// API provides a succinct and flexible interface for matching expressions and +// extracting useful information from them. + +#ifndef wasm_ir_match_h +#define wasm_ir_match_h + +#include "ir/abstract.h" +#include "wasm.h" + +namespace wasm { + +namespace Match { + +// The available matchers are: +// +// i32, i64, f32, f64 +// +// Match constants of the corresponding type. Takes zero or one argument. The +// argument can be a specific value to match or it can be a pointer to a +// value, Literal, or Const* at which to store the matched entity. +// +// ival, fval +// +// Match any integer constant or any floating point constant. Takes neither, +// either, or both of two possible arguments: first, a pointer to a value, +// Literal, or Const* at which to store the matched entity and second, a +// specific value to match. +// +// constant +// +// Matches any numeric Const expression. Takes neither, either, or both of +// two possible arguments: first, a pointer to either Literal or Const* at +// which to store the matched entity and second, a specific value (given as +// an int32_t) to match.. +// +// any +// +// Matches any Expression. Optionally takes as an argument a pointer to +// Expression* at which to store the matched Expression*. +// +// unary +// +// Matches Unary expressions. Takes an optional pointer to Unary* at which to +// store the matched Unary*, followed by either a UnaryOp or an Abstract::Op +// describing which unary expressions to match, followed by a matcher to +// apply to the unary expression's operand. +// +// binary +// +// Matches Binary expressions. Takes an optional pointer to Binary* at which +// to store the matched Binary*, followed by either a BinaryOp or an +// Abstract::Op describing which binary expresions to match, followed by +// matchers to apply to the binary expression's left and right operands. +// +// select +// +// Matches Select expressions. Takes an optional pointer to Select* at which +// to store the matched Select*, followed by matchers to apply to the ifTrue, +// ifFalse, and condition operands. +// +// +// How to create new matchers: +// +// Lets add a matcher for an expression type that is declared in wasm.h: +// +// class Frozzle : public SpecificExpression<Expression::FrozzleId> { +// public: +// Expression* foo; +// Expression* bar; +// Expression* baz; +// }; +// +// This expression is very simple; in order to match it, all we need to do is +// apply other matchers to its subexpressions. The matcher infrastructure will +// handle this automatically once we tell it how to access the subexpressions. +// To tell the matcher infrastructure how many subexpressions there are we need +// to specialize `NumComponents`. +// +// template<> struct NumComponents<Frozzle*> { +// static constexpr size_t value = 3; +// }; +// +// And to tell the matcher infrastructure how to access those three +// subexpressions, we need to specialize `GetComponent` three times. +// +// template<> struct GetComponent<Frozzle*, 0> { +// Expression* operator()(Frozzle* curr) { return curr->foo; } +// }; +// template<> struct GetComponent<Frozzle*, 1> { +// Expression* operator()(Frozzle* curr) { return curr->bar; } +// }; +// template<> struct GetComponent<Frozzle*, 2> { +// Expression* operator()(Frozzle* curr) { return curr->baz; } +// }; +// +// For simple expressions, that's all we need to do to get a fully functional +// matcher that we can construct and use like this, where S1, S2, and S3 are +// the types of the submatchers to use and s1, s2, and s3 are instances of +// those types: +// +// Frozzle* extracted; +// auto matcher = Matcher<Frozzle*, S1, S2, S3>(&extracted, {}, s1, s2, s3); +// if (matches(expr, matcher)) { +// // `extracted` set to `expr` here +// } +// +// It's annoying to have to write out the types S1, S2, and S3 and we don't get +// class template argument deduction (CTAD) until C++17, so it's useful to +// create a wrapper function so can take advantage of function template +// argument deduction. We can also take this opportunity to make the interface +// more compact. +// +// template<class S1, class S2, class S3> +// inline decltype(auto) frozzle(Frozzle** binder, +// S1&& s1, S2&& s2, S3&& s3) { +// return Matcher<Frozzle*, S1, S2, S3>(binder, {}, s1, s2, s3); +// } +// template<class S1, class S2, class S3> +// inline decltype(auto) frozzle(S1&& s1, S2&& s2, S3&& s3) { +// return Matcher<Frozzle*, S1, S2, S3>(nullptr, {}, s1, s2, s3); +// } +// +// Notice that we make the interface more compact by providing overloads with +// and without the binder. Here is the final matcher usage: +// +// Frozzle* extracted; +// if (matches(expr, frozzle(&extracted, s1, s2, s3))) { +// // `extracted` set to `expr` here +// } +// +// Some matchers are more complicated, though, because they need to do +// something besides just applying submatchers to the components of an +// expression. These matchers require slightly more work. +// +// +// Complex matchers: +// +// Lets add a matcher that will match calls to functions whose names start with +// certain prefixes. Since this is not a normal matcher for Call expressions, +// we can't identify it by the Call* type. Instead, we have to create a new +// identifier type, called a "Kind" for it. +// +// struct PrefixCallKind {}; +// +// Next, since we're not in the common case of using a specific expression +// pointer as our kind, we have to tell the matcher infrastructure what type of +// thing this matcher matches. Since we want this matcher to be able to match +// any given prefix, we also need the matcher to contain the given prefix as +// state, and we need to tell the matcher infrastructure what type that state +// is as well. To specify these types, we need to specialize +// `KindTypeRegistry` for `PrefixCallKind`. +// +// template<> struct KindTypeRegistry<PrefixCallKind> { +// using matched_t = Call*; +// using data_t = Name; +// }; +// +// Note that because `matched_t` is set to a specific expression pointer, this +// matcher will automatically be able to be applied to any `Expression*`, not +// just `Call*`. If `matched_t` were not a specific expression pointer, this +// matcher would only be able to be applied to types compatible with +// `matched_t`. Also note that if a matcher does not need to store any state, +// its `data_t` should be set to `unused_t`. +// +// Now we need to tell the matcher infrastructure what custom logic to apply +// for this matcher. We do this by specializing `MatchSelf`. +// +// template<> struct MatchSelf<PrefixCallKind> { +// bool operator()(Call* curr, Name prefix) { +// return curr->name.startsWith(prefix); +// } +// }; +// +// Note that the first parameter to `MatchSelf<Kind>::operator()` will be that +// kind's `matched_t` and the second parameter will be that kind's `data_t`, +// which may be `unused_t`. (TODO: detect if `data_t` is `unused_t` and don't +// expose it in the Matcher interface if so.) +// +// After this, everything is the same as in the simple matcher case. This +// particular matcher doesn't need to recurse into any subcomponents, so we can +// skip straight to creating the wrapper function. +// +// decltype(auto) prefixCall(Call** binder, Name prefix) { +// return Matcher<PrefixCallKind>(binder, prefix); +// } +// +// Now we can use the new matcher: +// +// Call* call; +// if (matches(expr, prefixCall(&call, "__foo"))) { +// // `call` set to `expr` here +// } +// + +// The main entrypoint for matching. If the match succeeds, all variables bound +// in the matcher will be set to their corresponding matched values. Otherwise, +// the value of the bound variables is unspecified and may have changed. +template<class Matcher> inline bool matches(Expression* expr, Matcher matcher) { + return matcher.matches(expr); +} + +namespace Internal { + +struct unused_t {}; + +// Each matcher has a `Kind`, which controls how candidate values are +// destructured and inspected. For most matchers, `Kind` is a pointer to the +// matched subtype of Expression, but when there are multiple matchers for the +// same kind of expression, they are disambiguated by having different `Kind`s. +// In this case, or if the matcher matches something besides a pointer to a +// subtype of Expression, or if the matcher requires additional state, the +// matched type and the type of additional state must be associated with the +// `Kind` via a specialization of `KindTypeRegistry`. +template<class Kind> struct KindTypeRegistry { + // The matched type + using matched_t = void; + // The type of additional state needed to perform a match. Can be set to + // `unused_t` if it's not needed. + using data_t = unused_t; +}; + +// Given a `Kind`, produce the type `matched_t` that is matched by that Kind and +// the type `candidate_t` that is the type of the parameter of the `matches` +// method. These types are only different if `matched_t` is a pointer to a +// subtype of Expression, in which case `candidate_t` is Expression*. +template<class Kind> struct MatchTypes { + using matched_t = typename std::conditional_t< + std::is_base_of<Expression, std::remove_pointer_t<Kind>>::value, + Kind, + typename KindTypeRegistry<Kind>::matched_t>; + + static constexpr bool isExpr = + std::is_base_of<Expression, std::remove_pointer_t<matched_t>>::value; + + using candidate_t = + typename std::conditional_t<isExpr, Expression*, matched_t>; +}; + +template<class Kind> using matched_t = typename MatchTypes<Kind>::matched_t; +template<class Kind> using candidate_t = typename MatchTypes<Kind>::candidate_t; +template<class Kind> using data_t = typename KindTypeRegistry<Kind>::data_t; + +// Defined if the matched type is a specific expression pointer, so can be +// `dynCast`ed to from Expression*. +template<class Kind> +using enable_if_castable_t = typename std::enable_if< + std::is_base_of<Expression, std::remove_pointer_t<matched_t<Kind>>>::value && + !std::is_same<Expression*, matched_t<Kind>>::value, + int>::type; + +// Opposite of above +template<class Kind> +using enable_if_not_castable_t = typename std::enable_if< + !std::is_base_of<Expression, std::remove_pointer_t<matched_t<Kind>>>::value || + std::is_same<Expression*, matched_t<Kind>>::value, + int>::type; + +// Do a normal dynCast from Expression* to the subtype, storing the result in +// `out` and returning `true` iff the cast succeeded. +template<class Kind, enable_if_castable_t<Kind> = 0> +inline bool dynCastCandidate(candidate_t<Kind> candidate, + matched_t<Kind>& out) { + out = candidate->template dynCast<std::remove_pointer_t<matched_t<Kind>>>(); + return out != nullptr; +} + +// Otherwise we are not matching an Expression, so this is infallible. +template<class Kind, enable_if_not_castable_t<Kind> = 0> +inline bool dynCastCandidate(candidate_t<Kind> candidate, + matched_t<Kind>& out) { + out = candidate; + return true; +} + +// Matchers can optionally specialize this to perform custom matching logic +// before recursing into submatchers, potentially short-circuiting the match. +// Uses a struct because partial specialization of functions is not allowed. +template<class Kind> struct MatchSelf { + bool operator()(matched_t<Kind>, data_t<Kind>) { return true; } +}; + +// Used to statically ensure that each matcher has the correct number of +// submatchers. This needs to be specialized for each kind of matcher that has +// submatchers. +template<class Kind> struct NumComponents { + static constexpr size_t value = 0; +}; + +// Every kind of matcher needs to partially specialize this for each of its +// components. Each specialization should define +// +// T operator()(matched_t<Kind>) +// +// where T is the component's type. Components will be matched from first to +// last. Uses a struct instead of a function because partial specialization of +// functions is not allowed. +template<class Kind, int pos> struct GetComponent; + +// A type-level linked list to hold an arbitrary number of matchers. +template<class...> struct SubMatchers {}; +template<class CurrMatcher, class... NextMatchers> +struct SubMatchers<CurrMatcher, NextMatchers...> { + CurrMatcher curr; + SubMatchers<NextMatchers...> next; + SubMatchers(CurrMatcher curr, NextMatchers... next) + : curr(curr), next(next...){}; +}; + +// Iterates through the components of the candidate, applying a submatcher to +// each component. Uses a struct instead of a function because partial +// specialization of functions is not allowed. +template<class Kind, int pos, class CurrMatcher = void, class... NextMatchers> +struct Components { + static inline bool + match(matched_t<Kind> candidate, + SubMatchers<CurrMatcher, NextMatchers...>& matchers) { + return matchers.curr.matches(GetComponent<Kind, pos>{}(candidate)) && + Components<Kind, pos + 1, NextMatchers...>::match(candidate, + matchers.next); + } +}; +template<class Kind, int pos> struct Components<Kind, pos> { + static_assert(pos == NumComponents<Kind>::value, + "Unexpected number of submatchers"); + static inline bool match(matched_t<Kind>, SubMatchers<>) { + // Base case when there are no components left; trivially true. + return true; + } +}; + +template<class Kind, class... Matchers> struct Matcher { + matched_t<Kind>* binder; + data_t<Kind> data; + SubMatchers<Matchers...> submatchers; + + Matcher(matched_t<Kind>* binder, data_t<Kind> data, Matchers... submatchers) + : binder(binder), data(data), submatchers(submatchers...) {} + + inline bool matches(candidate_t<Kind> candidate) { + matched_t<Kind> casted; + if (dynCastCandidate<Kind>(candidate, casted)) { + if (binder != nullptr) { + *binder = casted; + } + return MatchSelf<Kind>{}(casted, data) && + Components<Kind, 0, Matchers...>::match(casted, submatchers); + } + return false; + } +}; + +// Concrete low-level matcher implementations. Not intended for direct external +// use. + +// Any<T>: matches any value of the expected type +template<class T> struct AnyKind {}; +template<class T> struct KindTypeRegistry<AnyKind<T>> { + using matched_t = T; + using data_t = unused_t; +}; +template<class T> inline decltype(auto) Any(T* binder) { + return Matcher<AnyKind<T>>(binder, {}); +} + +// Exact<T>: matches exact values of the expected type +template<class T> struct ExactKind {}; +template<class T> struct KindTypeRegistry<ExactKind<T>> { + using matched_t = T; + using data_t = T; +}; +template<class T> struct MatchSelf<ExactKind<T>> { + bool operator()(T self, T expected) { return self == expected; } +}; +template<class T> inline decltype(auto) Exact(T* binder, T data) { + return Matcher<ExactKind<T>>(binder, data); +} + +// {I32,I64,Int,F32,F64,Float,Number}Lit: match `Literal` of the expected `Type` +struct I32LK { + static bool matchType(Literal lit) { return lit.type == Type::i32; } + static int32_t getVal(Literal lit) { return lit.geti32(); } +}; +struct I64LK { + static bool matchType(Literal lit) { return lit.type == Type::i64; } + static int64_t getVal(Literal lit) { return lit.geti64(); } +}; +struct IntLK { + static bool matchType(Literal lit) { return lit.type.isInteger(); } + static int64_t getVal(Literal lit) { return lit.getInteger(); } +}; +struct F32LK { + static bool matchType(Literal lit) { return lit.type == Type::f32; } + static float getVal(Literal lit) { return lit.getf32(); } +}; +struct F64LK { + static bool matchType(Literal lit) { return lit.type == Type::f64; } + static double getVal(Literal lit) { return lit.getf64(); } +}; +struct FloatLK { + static bool matchType(Literal lit) { return lit.type.isFloat(); } + static double getVal(Literal lit) { return lit.getFloat(); } +}; +template<class T> struct LitKind {}; +template<class T> struct KindTypeRegistry<LitKind<T>> { + using matched_t = Literal; + using data_t = unused_t; +}; +template<class T> struct MatchSelf<LitKind<T>> { + bool operator()(Literal lit, unused_t) { return T::matchType(lit); } +}; +template<class T> struct NumComponents<LitKind<T>> { + static constexpr size_t value = 1; +}; +template<class T> struct GetComponent<LitKind<T>, 0> { + decltype(auto) operator()(Literal lit) { return T::getVal(lit); } +}; +template<class S> inline decltype(auto) I32Lit(Literal* binder, S&& s) { + return Matcher<LitKind<I32LK>, S>(binder, {}, s); +} +template<class S> inline decltype(auto) I64Lit(Literal* binder, S&& s) { + return Matcher<LitKind<I64LK>, S>(binder, {}, s); +} +template<class S> inline decltype(auto) IntLit(Literal* binder, S&& s) { + return Matcher<LitKind<IntLK>, S>(binder, {}, s); +} +template<class S> inline decltype(auto) F32Lit(Literal* binder, S&& s) { + return Matcher<LitKind<F32LK>, S>(binder, {}, s); +} +template<class S> inline decltype(auto) F64Lit(Literal* binder, S&& s) { + return Matcher<LitKind<F64LK>, S>(binder, {}, s); +} +template<class S> inline decltype(auto) FloatLit(Literal* binder, S&& s) { + return Matcher<LitKind<FloatLK>, S>(binder, {}, s); +} +struct NumberLitKind {}; +template<> struct KindTypeRegistry<NumberLitKind> { + using matched_t = Literal; + using data_t = int32_t; +}; +template<> struct MatchSelf<NumberLitKind> { + bool operator()(Literal lit, int32_t expected) { + return lit.type.isNumber() && + Literal::makeFromInt32(expected, lit.type) == lit; + } +}; +inline decltype(auto) NumberLit(Literal* binder, int32_t expected) { + return Matcher<NumberLitKind>(binder, expected); +} + +// Const +template<> struct NumComponents<Const*> { static constexpr size_t value = 1; }; +template<> struct GetComponent<Const*, 0> { + Literal operator()(Const* c) { return c->value; } +}; +template<class S> inline decltype(auto) ConstMatcher(Const** binder, S&& s) { + return Matcher<Const*, S>(binder, {}, s); +} + +// Unary and AbstractUnary +struct UnaryK { + using Op = UnaryOp; + static UnaryOp getOp(Type, Op op) { return op; } +}; +struct AbstractUnaryK { + using Op = Abstract::Op; + static UnaryOp getOp(Type type, Abstract::Op op) { + return Abstract::getUnary(type, op); + } +}; +template<class T> struct UnaryKind {}; +template<class T> struct KindTypeRegistry<UnaryKind<T>> { + using matched_t = Unary*; + using data_t = typename T::Op; +}; +template<class T> struct MatchSelf<UnaryKind<T>> { + bool operator()(Unary* curr, typename T::Op op) { + return curr->op == T::getOp(curr->value->type, op); + } +}; +template<class T> struct NumComponents<UnaryKind<T>> { + static constexpr size_t value = 1; +}; +template<class T> struct GetComponent<UnaryKind<T>, 0> { + Expression* operator()(Unary* curr) { return curr->value; } +}; +template<class S> +inline decltype(auto) UnaryMatcher(Unary** binder, UnaryOp op, S&& s) { + return Matcher<UnaryKind<UnaryK>, S>(binder, op, s); +} +template<class S> +inline decltype(auto) +AbstractUnaryMatcher(Unary** binder, Abstract::Op op, S&& s) { + return Matcher<UnaryKind<AbstractUnaryK>, S>(binder, op, s); +} + +// Binary and AbstractBinary +struct BinaryK { + using Op = BinaryOp; + static BinaryOp getOp(Type, Op op) { return op; } +}; +struct AbstractBinaryK { + using Op = Abstract::Op; + static BinaryOp getOp(Type type, Abstract::Op op) { + return Abstract::getBinary(type, op); + } +}; +template<class T> struct BinaryKind {}; +template<class T> struct KindTypeRegistry<BinaryKind<T>> { + using matched_t = Binary*; + using data_t = typename T::Op; +}; +template<class T> struct MatchSelf<BinaryKind<T>> { + bool operator()(Binary* curr, typename T::Op op) { + return curr->op == T::getOp(curr->left->type, op); + } +}; +template<class T> struct NumComponents<BinaryKind<T>> { + static constexpr size_t value = 2; +}; +template<class T> struct GetComponent<BinaryKind<T>, 0> { + Expression* operator()(Binary* curr) { return curr->left; } +}; +template<class T> struct GetComponent<BinaryKind<T>, 1> { + Expression* operator()(Binary* curr) { return curr->right; } +}; +template<class S1, class S2> +inline decltype(auto) +BinaryMatcher(Binary** binder, BinaryOp op, S1&& s1, S2&& s2) { + return Matcher<BinaryKind<BinaryK>, S1, S2>(binder, op, s1, s2); +} +template<class S1, class S2> +inline decltype(auto) +AbstractBinaryMatcher(Binary** binder, Abstract::Op op, S1&& s1, S2&& s2) { + return Matcher<BinaryKind<AbstractBinaryK>, S1, S2>(binder, op, s1, s2); +} + +// Select +template<> struct NumComponents<Select*> { static constexpr size_t value = 3; }; +template<> struct GetComponent<Select*, 0> { + Expression* operator()(Select* curr) { return curr->ifTrue; } +}; +template<> struct GetComponent<Select*, 1> { + Expression* operator()(Select* curr) { return curr->ifFalse; } +}; +template<> struct GetComponent<Select*, 2> { + Expression* operator()(Select* curr) { return curr->condition; } +}; +template<class S1, class S2, class S3> +inline decltype(auto) +SelectMatcher(Select** binder, S1&& s1, S2&& s2, S3&& s3) { + return Matcher<Select*, S1, S2, S3>(binder, {}, s1, s2, s3); +} + +} // namespace Internal + +// Public matching API + +inline decltype(auto) i32() { + return Internal::ConstMatcher( + nullptr, Internal::I32Lit(nullptr, Internal::Any<int32_t>(nullptr))); +} +// Use int rather than int32_t to disambiguate literal 0, which otherwise could +// be resolved to either the int32_t overload or any of the pointer overloads. +inline decltype(auto) i32(int x) { + return Internal::ConstMatcher( + nullptr, Internal::I32Lit(nullptr, Internal::Exact<int32_t>(nullptr, x))); +} +inline decltype(auto) i32(int32_t* binder) { + return Internal::ConstMatcher( + nullptr, Internal::I32Lit(nullptr, Internal::Any(binder))); +} +inline decltype(auto) i32(Literal* binder) { + return Internal::ConstMatcher( + nullptr, Internal::I32Lit(binder, Internal::Any<int32_t>(nullptr))); +} +inline decltype(auto) i32(Const** binder) { + return Internal::ConstMatcher( + binder, Internal::I32Lit(nullptr, Internal::Any<int32_t>(nullptr))); +} + +inline decltype(auto) i64() { + return Internal::ConstMatcher( + nullptr, Internal::I64Lit(nullptr, Internal::Any<int64_t>(nullptr))); +} +inline decltype(auto) i64(int64_t x) { + return Internal::ConstMatcher( + nullptr, Internal::I64Lit(nullptr, Internal::Exact<int64_t>(nullptr, x))); +} +// Disambiguate literal 0, which could otherwise be interpreted as a pointer +inline decltype(auto) i64(int x) { return i64(int64_t(x)); } +inline decltype(auto) i64(int64_t* binder) { + return Internal::ConstMatcher( + nullptr, Internal::I64Lit(nullptr, Internal::Any(binder))); +} +inline decltype(auto) i64(Literal* binder) { + return Internal::ConstMatcher( + nullptr, Internal::I64Lit(binder, Internal::Any<int64_t>(nullptr))); +} +inline decltype(auto) i64(Const** binder) { + return Internal::ConstMatcher( + binder, Internal::I64Lit(nullptr, Internal::Any<int64_t>(nullptr))); +} + +inline decltype(auto) f32() { + return Internal::ConstMatcher( + nullptr, Internal::F32Lit(nullptr, Internal::Any<float>(nullptr))); +} +inline decltype(auto) f32(float x) { + return Internal::ConstMatcher( + nullptr, Internal::F32Lit(nullptr, Internal::Exact<float>(nullptr, x))); +} +// Disambiguate literal 0, which could otherwise be interpreted as a pointer +inline decltype(auto) f32(int x) { return f32(float(x)); } +inline decltype(auto) f32(float* binder) { + return Internal::ConstMatcher( + nullptr, Internal::F32Lit(nullptr, Internal::Any(binder))); +} +inline decltype(auto) f32(Literal* binder) { + return Internal::ConstMatcher( + nullptr, Internal::F32Lit(binder, Internal::Any<float>(nullptr))); +} +inline decltype(auto) f32(Const** binder) { + return Internal::ConstMatcher( + binder, Internal::F32Lit(nullptr, Internal::Any<float>(nullptr))); +} + +inline decltype(auto) f64() { + return Internal::ConstMatcher( + nullptr, Internal::F64Lit(nullptr, Internal::Any<double>(nullptr))); +} +inline decltype(auto) f64(double x) { + return Internal::ConstMatcher( + nullptr, Internal::F64Lit(nullptr, Internal::Exact<double>(nullptr, x))); +} +// Disambiguate literal 0, which could otherwise be interpreted as a pointer +inline decltype(auto) f64(int x) { return f64(double(x)); } +inline decltype(auto) f64(double* binder) { + return Internal::ConstMatcher( + nullptr, Internal::F64Lit(nullptr, Internal::Any(binder))); +} +inline decltype(auto) f64(Literal* binder) { + return Internal::ConstMatcher( + nullptr, Internal::F64Lit(binder, Internal::Any<double>(nullptr))); +} +inline decltype(auto) f64(Const** binder) { + return Internal::ConstMatcher( + binder, Internal::F64Lit(nullptr, Internal::Any<double>(nullptr))); +} + +inline decltype(auto) ival() { + return Internal::ConstMatcher( + nullptr, Internal::IntLit(nullptr, Internal::Any<int64_t>(nullptr))); +} +inline decltype(auto) ival(int64_t x) { + return Internal::ConstMatcher( + nullptr, Internal::IntLit(nullptr, Internal::Exact<int64_t>(nullptr, x))); +} +// Disambiguate literal 0, which could otherwise be interpreted as a pointer +inline decltype(auto) ival(int x) { return ival(int64_t(x)); } +inline decltype(auto) ival(int64_t* binder) { + return Internal::ConstMatcher( + nullptr, Internal::IntLit(nullptr, Internal::Any(binder))); +} +inline decltype(auto) ival(Literal* binder) { + return Internal::ConstMatcher( + nullptr, Internal::IntLit(binder, Internal::Any<int64_t>(nullptr))); +} +inline decltype(auto) ival(Const** binder) { + return Internal::ConstMatcher( + binder, Internal::IntLit(nullptr, Internal::Any<int64_t>(nullptr))); +} +inline decltype(auto) ival(Literal* binder, int64_t x) { + return Internal::ConstMatcher( + nullptr, Internal::IntLit(binder, Internal::Exact<int64_t>(nullptr, x))); +} +inline decltype(auto) ival(Const** binder, int64_t x) { + return Internal::ConstMatcher( + binder, Internal::IntLit(nullptr, Internal::Exact<int64_t>(nullptr, x))); +} + +inline decltype(auto) fval() { + return Internal::ConstMatcher( + nullptr, Internal::FloatLit(nullptr, Internal::Any<double>(nullptr))); +} +inline decltype(auto) fval(double x) { + return Internal::ConstMatcher( + nullptr, Internal::FloatLit(nullptr, Internal::Exact<double>(nullptr, x))); +} +// Disambiguate literal 0, which could otherwise be interpreted as a pointer +inline decltype(auto) fval(int x) { return fval(double(x)); } +inline decltype(auto) fval(double* binder) { + return Internal::ConstMatcher( + nullptr, Internal::FloatLit(nullptr, Internal::Any(binder))); +} +inline decltype(auto) fval(Literal* binder) { + return Internal::ConstMatcher( + nullptr, Internal::FloatLit(binder, Internal::Any<double>(nullptr))); +} +inline decltype(auto) fval(Const** binder) { + return Internal::ConstMatcher( + binder, Internal::FloatLit(nullptr, Internal::Any<double>(nullptr))); +} +inline decltype(auto) fval(Literal* binder, double x) { + return Internal::ConstMatcher( + nullptr, Internal::FloatLit(binder, Internal::Exact<double>(nullptr, x))); +} +inline decltype(auto) fval(Const** binder, double x) { + return Internal::ConstMatcher( + binder, Internal::FloatLit(nullptr, Internal::Exact<double>(nullptr, x))); +} + +inline decltype(auto) constant() { + return Internal::ConstMatcher(nullptr, Internal::Any<Literal>(nullptr)); +} +inline decltype(auto) constant(int x) { + return Internal::ConstMatcher(nullptr, Internal::NumberLit(nullptr, x)); +} +inline decltype(auto) constant(Literal* binder) { + return Internal::ConstMatcher(nullptr, Internal::Any(binder)); +} +inline decltype(auto) constant(Const** binder) { + return Internal::ConstMatcher(binder, Internal::Any<Literal>(nullptr)); +} +inline decltype(auto) constant(Literal* binder, int32_t x) { + return Internal::ConstMatcher(nullptr, Internal::NumberLit(binder, x)); +} +inline decltype(auto) constant(Const** binder, int32_t x) { + return Internal::ConstMatcher(binder, Internal::NumberLit(nullptr, x)); +} + +inline decltype(auto) any() { return Internal::Any<Expression*>(nullptr); } +inline decltype(auto) any(Expression** binder) { return Internal::Any(binder); } + +template<class S> inline decltype(auto) unary(UnaryOp op, S&& s) { + return Internal::UnaryMatcher(nullptr, op, s); +} +template<class S> inline decltype(auto) unary(Abstract::Op op, S&& s) { + return Internal::AbstractUnaryMatcher(nullptr, op, s); +} +template<class S> +inline decltype(auto) unary(Unary** binder, UnaryOp op, S&& s) { + return Internal::UnaryMatcher(binder, op, s); +} +template<class S> +inline decltype(auto) unary(Unary** binder, Abstract::Op op, S&& s) { + return Internal::AbstractUnaryMatcher(binder, op, s); +} + +template<class S1, class S2> +inline decltype(auto) binary(BinaryOp op, S1&& s1, S2&& s2) { + return Internal::BinaryMatcher(nullptr, op, s1, s2); +} +template<class S1, class S2> +inline decltype(auto) binary(Abstract::Op op, S1&& s1, S2&& s2) { + return Internal::AbstractBinaryMatcher(nullptr, op, s1, s2); +} +template<class S1, class S2> +inline decltype(auto) binary(Binary** binder, BinaryOp op, S1&& s1, S2&& s2) { + return Internal::BinaryMatcher(binder, op, s1, s2); +} +template<class S1, class S2> +inline decltype(auto) +binary(Binary** binder, Abstract::Op op, S1&& s1, S2&& s2) { + return Internal::AbstractBinaryMatcher(binder, op, s1, s2); +} + +template<class S1, class S2, class S3> +inline decltype(auto) select(S1&& s1, S2&& s2, S3&& s3) { + return Internal::SelectMatcher(nullptr, s1, s2, s3); +} +template<class S1, class S2, class S3> +inline decltype(auto) select(Select** binder, S1&& s1, S2&& s2, S3&& s3) { + return Internal::SelectMatcher(binder, s1, s2, s3); +} + +} // namespace Match + +} // namespace wasm + +#endif // wasm_ir_match_h diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 9cc16004d..1c2445d5d 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -28,6 +28,7 @@ #include <ir/literal-utils.h> #include <ir/load-utils.h> #include <ir/manipulation.h> +#include <ir/match.h> #include <ir/properties.h> #include <ir/utils.h> #include <pass.h> @@ -132,6 +133,19 @@ struct LocalScanner : PostWalker<LocalScanner> { } }; +// Create a custom matcher for checking side effects +template<class Opt> struct PureMatcherKind {}; +template<class Opt> +struct Match::Internal::KindTypeRegistry<PureMatcherKind<Opt>> { + using matched_t = Expression*; + using data_t = Opt*; +}; +template<class Opt> struct Match::Internal::MatchSelf<PureMatcherKind<Opt>> { + bool operator()(Expression* curr, Opt* opt) { + return !opt->effects(curr).hasSideEffects(); + } +}; + // Main pass class struct OptimizeInstructions : public WalkerPass< @@ -189,6 +203,20 @@ struct OptimizeInstructions } } + EffectAnalyzer effects(Expression* expr) { + return EffectAnalyzer(getPassOptions(), getModule()->features, expr); + } + + decltype(auto) pure(Expression** binder) { + using namespace Match::Internal; + return Matcher<PureMatcherKind<OptimizeInstructions>>(binder, this); + } + + bool canReorder(Expression* a, Expression* b) { + return EffectAnalyzer::canReorder( + getPassOptions(), getModule()->features, a, b); + } + // Optimizations that don't yet fit in the pattern DSL, but could be // eventually maybe Expression* handOptimize(Expression* curr) { @@ -205,6 +233,104 @@ struct OptimizeInstructions if (isSymmetric(binary)) { canonicalize(binary); } + } + + { + // TODO: It is an ongoing project to port more transformations to the + // match API. Once most of the transformations have been ported, the + // `using namespace Match` can be hoisted to function scope and this extra + // block scope can be removed. + using namespace Match; + Builder builder(*getModule()); + { + // X == 0 => eqz X + Expression* x; + if (matches(curr, binary(EqInt32, any(&x), i32(0)))) { + return Builder(*getModule()).makeUnary(EqZInt32, x); + } + } + { + // try to get rid of (0 - ..), that is, a zero only used to negate an + // int. an add of a subtract can be flipped in order to remove it: + // (i32.add + // (i32.sub + // (i32.const 0) + // X + // ) + // Y + // ) + // => + // (i32.sub + // Y + // X + // ) + // Note that this reorders X and Y, so we need to be careful about that. + Expression *x, *y; + Binary* sub; + if (matches(curr, + binary(AddInt32, + binary(&sub, SubInt32, i32(0), any(&x)), + any(&y))) && + canReorder(x, y)) { + sub->left = y; + sub->right = x; + return sub; + } + } + { + // The flip case is even easier, as no reordering occurs: + // (i32.add + // Y + // (i32.sub + // (i32.const 0) + // X + // ) + // ) + // => + // (i32.sub + // Y + // X + // ) + Expression* y; + Binary* sub; + if (matches(curr, + binary(AddInt32, + any(&y), + binary(&sub, SubInt32, i32(0), any())))) { + sub->left = y; + return sub; + } + } + { + // try de-morgan's AND law, + // (eqz X) and (eqz Y) === eqz (X or Y) + // Note that the OR and XOR laws do not work here, as these + // are not booleans (we could check if they are, but a boolean + // would already optimize with the eqz anyhow, unless propagating). + // But for AND, the left is true iff X and Y are each all zero bits, + // and the right is true if the union of their bits is zero; same. + Unary* un; + Binary* bin; + Expression *x, *y; + if (matches(curr, + binary(&bin, + AndInt32, + unary(&un, EqZInt32, any(&x)), + unary(EqZInt32, any(&y))))) { + bin->op = OrInt32; + bin->left = x; + bin->right = y; + un->value = bin; + return un; + } + } + } + + if (auto* select = curr->dynCast<Select>()) { + return optimizeSelect(select); + } + + if (auto* binary = curr->dynCast<Binary>()) { if (auto* ext = Properties::getAlmostSignExt(binary)) { Index extraShifts; auto bits = Properties::getAlmostSignExtBits(binary, extraShifts); @@ -313,57 +439,6 @@ struct OptimizeInstructions // note that both left and right may be consts, but then we let // precompute compute the constant result } else if (binary->op == AddInt32) { - // try to get rid of (0 - ..), that is, a zero only used to negate an - // int. an add of a subtract can be flipped in order to remove it: - // (i32.add - // (i32.sub - // (i32.const 0) - // X - // ) - // Y - // ) - // => - // (i32.sub - // Y - // X - // ) - // Note that this reorders X and Y, so we need to be careful about that. - if (auto* sub = binary->left->dynCast<Binary>()) { - if (sub->op == SubInt32) { - if (auto* subZero = sub->left->dynCast<Const>()) { - if (subZero->value.geti32() == 0) { - if (EffectAnalyzer::canReorder( - getPassOptions(), features, sub->right, binary->right)) { - sub->left = binary->right; - return sub; - } - } - } - } - } - // The flip case is even easier, as no reordering occurs: - // (i32.add - // Y - // (i32.sub - // (i32.const 0) - // X - // ) - // ) - // => - // (i32.sub - // Y - // X - // ) - if (auto* sub = binary->right->dynCast<Binary>()) { - if (sub->op == SubInt32) { - if (auto* subZero = sub->left->dynCast<Const>()) { - if (subZero->value.geti32() == 0) { - sub->left = binary->left; - return sub; - } - } - } - } if (auto* ret = optimizeAddedConstants(binary)) { return ret; } @@ -463,30 +538,6 @@ struct OptimizeInstructions } } // bitwise operations - if (binary->op == AndInt32) { - // try de-morgan's AND law, - // (eqz X) and (eqz Y) === eqz (X or Y) - // Note that the OR and XOR laws do not work here, as these - // are not booleans (we could check if they are, but a boolean - // would already optimize with the eqz anyhow, unless propagating). - // But for AND, the left is true iff X and Y are each all zero bits, - // and the right is true if the union of their bits is zero; same. - if (auto* left = binary->left->dynCast<Unary>()) { - if (left->op == EqZInt32) { - if (auto* right = binary->right->dynCast<Unary>()) { - if (right->op == EqZInt32) { - // reuse one unary, drop the other - auto* leftValue = left->value; - left->value = binary; - binary->left = leftValue; - binary->right = right->value; - binary->op = OrInt32; - return left; - } - } - } - } - } // for and and or, we can potentially conditionalize if (binary->op == AndInt32 || binary->op == OrInt32) { if (auto* ret = conditionalizeExpensiveOnBitwise(binary)) { @@ -507,8 +558,7 @@ struct OptimizeInstructions } // finally, try more expensive operations on the binary in // the case that they have no side effects - if (!EffectAnalyzer(getPassOptions(), features, binary->left) - .hasSideEffects()) { + if (!effects(binary->left).hasSideEffects()) { if (ExpressionAnalyzer::equal(binary->left, binary->right)) { if (auto* ret = optimizeBinaryWithEqualEffectlessChildren(binary)) { return ret; @@ -563,9 +613,7 @@ struct OptimizeInstructions // sides are identical, fold // if we can replace the if with one arm, and no side effects in the // condition, do that - auto needCondition = - EffectAnalyzer(getPassOptions(), features, iff->condition) - .hasSideEffects(); + auto needCondition = effects(iff->condition).hasSideEffects(); auto isSubType = Type::isSubType(iff->ifTrue->type, iff->type); if (isSubType && !needCondition) { return iff->ifTrue; @@ -591,92 +639,6 @@ struct OptimizeInstructions } } } - } else if (auto* select = curr->dynCast<Select>()) { - select->condition = optimizeBoolean(select->condition); - if (auto* c = select->condition->dynCast<Const>()) { - // constant condition, we can just pick the right side (barring side - // effects) - if (c->value.getInteger()) { - if (!EffectAnalyzer(getPassOptions(), features, select->ifFalse) - .hasSideEffects()) { - return select->ifTrue; - } else { - // don't bother - we would need to reverse the order using a temp - // local, which is bad - } - } else { - if (!EffectAnalyzer(getPassOptions(), features, select->ifTrue) - .hasSideEffects()) { - return select->ifFalse; - } else { - Builder builder(*getModule()); - return builder.makeSequence(builder.makeDrop(select->ifTrue), - select->ifFalse); - } - } - } - if (auto* constTrue = select->ifTrue->dynCast<Const>()) { - if (auto* constFalse = select->ifFalse->dynCast<Const>()) { - if (select->type == Type::i32 || select->type == Type::i64) { - auto trueValue = constTrue->value.getInteger(); - auto falseValue = constFalse->value.getInteger(); - if ((trueValue == 1LL && falseValue == 0LL) || - (trueValue == 0LL && falseValue == 1LL)) { - Builder builder(*getModule()); - Expression* condition = select->condition; - if (trueValue == 0LL) { - condition = - optimizeBoolean(builder.makeUnary(EqZInt32, condition)); - } - if (!Properties::emitsBoolean(condition)) { - // expr ? 1 : 0 ==> !!expr - condition = builder.makeUnary( - EqZInt32, builder.makeUnary(EqZInt32, condition)); - } - return select->type == Type::i64 - ? builder.makeUnary(ExtendUInt32, condition) - : condition; - } - } - } - } - if (auto* condition = select->condition->dynCast<Unary>()) { - if (condition->op == EqZInt32) { - // flip select to remove eqz, if we can reorder - EffectAnalyzer ifTrue(getPassOptions(), features, select->ifTrue); - EffectAnalyzer ifFalse(getPassOptions(), features, select->ifFalse); - if (!ifTrue.invalidates(ifFalse)) { - select->condition = condition->value; - std::swap(select->ifTrue, select->ifFalse); - } - } - } - if (ExpressionAnalyzer::equal(select->ifTrue, select->ifFalse)) { - // sides are identical, fold - EffectAnalyzer value(getPassOptions(), features, select->ifTrue); - if (value.hasSideEffects()) { - // at best we don't need the condition, but need to execute the value - // twice. a block is larger than a select by 2 bytes, and - // we must drop one value, so 3, while we save the condition, - // so it's not clear this is worth it, TODO - } else { - // value has no side effects - EffectAnalyzer condition( - getPassOptions(), features, select->condition); - if (!condition.hasSideEffects()) { - return select->ifTrue; - } else { - // the condition is last, so we need a new local, and it may be - // a bad idea to use a block like we do for an if. Do it only if we - // can reorder - if (!condition.invalidates(value)) { - Builder builder(*getModule()); - return builder.makeSequence(builder.makeDrop(select->condition), - select->ifTrue); - } - } - } - } } else if (auto* br = curr->dynCast<Break>()) { if (br->condition) { br->condition = optimizeBoolean(br->condition); @@ -734,15 +696,12 @@ private: // write more concise pattern matching code elsewhere. void canonicalize(Binary* binary) { assert(isSymmetric(binary)); - FeatureSet features = getModule()->features; auto swap = [&]() { - assert(EffectAnalyzer::canReorder( - getPassOptions(), features, binary->left, binary->right)); + assert(canReorder(binary->left, binary->right)); std::swap(binary->left, binary->right); }; auto maybeSwap = [&]() { - if (EffectAnalyzer::canReorder( - getPassOptions(), features, binary->left, binary->right)) { + if (canReorder(binary->left, binary->right)) { swap(); } }; @@ -860,6 +819,89 @@ private: return boolean; } + Expression* optimizeSelect(Select* curr) { + using namespace Match; + Builder builder(*getModule()); + curr->condition = optimizeBoolean(curr->condition); + { + // Constant condition, we can just pick the correct side (barring side + // effects) + Expression *ifTrue, *ifFalse; + if (matches(curr, select(pure(&ifTrue), any(&ifFalse), i32(0)))) { + return ifFalse; + } + if (matches(curr, select(any(&ifTrue), any(&ifFalse), i32(0)))) { + return builder.makeSequence(builder.makeDrop(ifTrue), ifFalse); + } + int32_t cond; + if (matches(curr, select(any(&ifTrue), pure(&ifFalse), i32(&cond)))) { + // The condition must be non-zero because a zero would have matched one + // of the previous patterns. + assert(cond != 0); + return ifTrue; + } + // Don't bother when `ifFalse` isn't pure - we would need to reverse the + // order using a temp local, which would be bad + } + { + // Flip select to remove eqz if we can reorder + Select* s; + Expression *ifTrue, *ifFalse, *c; + if (matches( + curr, + select( + &s, any(&ifTrue), any(&ifFalse), unary(EqZInt32, any(&c)))) && + canReorder(ifTrue, ifFalse)) { + s->ifTrue = ifFalse; + s->ifFalse = ifTrue; + s->condition = c; + } + } + { + // Simplify selects between 0 and 1 + Expression* c; + bool reversed = matches(curr, select(ival(0), ival(1), any(&c))); + if (reversed || matches(curr, select(ival(1), ival(0), any(&c)))) { + if (reversed) { + c = optimizeBoolean(builder.makeUnary(EqZInt32, c)); + } + if (!Properties::emitsBoolean(c)) { + // cond ? 1 : 0 ==> !!cond + c = builder.makeUnary(EqZInt32, builder.makeUnary(EqZInt32, c)); + } + return curr->type == Type::i64 ? builder.makeUnary(ExtendUInt32, c) : c; + } + } + { + // Sides are identical, fold + Expression *ifTrue, *ifFalse, *c; + if (matches(curr, select(any(&ifTrue), any(&ifFalse), any(&c))) && + ExpressionAnalyzer::equal(ifTrue, ifFalse)) { + auto value = effects(ifTrue); + if (value.hasSideEffects()) { + // At best we don't need the condition, but need to execute the + // value twice. a block is larger than a select by 2 bytes, and we + // must drop one value, so 3, while we save the condition, so it's + // not clear this is worth it, TODO + } else { + // value has no side effects + auto condition = effects(c); + if (!condition.hasSideEffects()) { + return ifTrue; + } else { + // The condition is last, so we need a new local, and it may be a + // bad idea to use a block like we do for an if. Do it only if we + // can reorder + if (!condition.invalidates(value)) { + return builder.makeSequence(builder.makeDrop(c), ifTrue); + } + } + } + } + } + return nullptr; + } + // find added constants in an expression tree, including multiplied/shifted, // and combine them note that we ignore division/shift-right, as rounding // makes this nonlinear, so not a valid opt @@ -1024,9 +1066,8 @@ private: if (!Properties::emitsBoolean(left) || !Properties::emitsBoolean(right)) { return nullptr; } - FeatureSet features = getModule()->features; - auto leftEffects = EffectAnalyzer(getPassOptions(), features, left); - auto rightEffects = EffectAnalyzer(getPassOptions(), features, right); + auto leftEffects = effects(left); + auto rightEffects = effects(right); auto leftHasSideEffects = leftEffects.hasSideEffects(); auto rightHasSideEffects = rightEffects.hasSideEffects(); if (leftHasSideEffects && rightHasSideEffects) { @@ -1072,16 +1113,13 @@ private: // (x > y) | (x == y) ==> x >= y Expression* combineOr(Binary* binary) { assert(binary->op == OrInt32); - FeatureSet features = getModule()->features; if (auto* left = binary->left->dynCast<Binary>()) { if (auto* right = binary->right->dynCast<Binary>()) { if (left->op != right->op && ExpressionAnalyzer::equal(left->left, right->left) && ExpressionAnalyzer::equal(left->right, right->right) && - !EffectAnalyzer(getPassOptions(), features, left->left) - .hasSideEffects() && - !EffectAnalyzer(getPassOptions(), features, left->right) - .hasSideEffects()) { + !effects(left->left).hasSideEffects() && + !effects(left->right).hasSideEffects()) { switch (left->op) { // (x > y) | (x == y) ==> x >= y case EqInt32: { @@ -1202,185 +1240,165 @@ private: // optimize trivial math operations, given that the right side of a binary // is a constant - // TODO: templatize on type? - Expression* optimizeWithConstantOnRight(Binary* binary) { - FeatureSet features = getModule()->features; - auto type = binary->right->type; - auto* right = binary->right->cast<Const>(); - if (type.isInteger()) { - auto constRight = right->value.getInteger(); - // operations on zero - if (constRight == 0LL) { - if (binary->op == Abstract::getBinary(type, Abstract::Shl) || - binary->op == Abstract::getBinary(type, Abstract::ShrU) || - binary->op == Abstract::getBinary(type, Abstract::ShrS) || - binary->op == Abstract::getBinary(type, Abstract::Or) || - binary->op == Abstract::getBinary(type, Abstract::Xor)) { - return binary->left; - } else if ((binary->op == Abstract::getBinary(type, Abstract::Mul) || - binary->op == Abstract::getBinary(type, Abstract::And)) && - !EffectAnalyzer(getPassOptions(), features, binary->left) - .hasSideEffects()) { - return binary->right; - } else if (binary->op == Abstract::getBinary(type, Abstract::Eq)) { - return Builder(*getModule()) - .makeUnary(Abstract::getUnary(type, Abstract::EqZ), binary->left); - } - } - // operations on one - if (constRight == 1LL) { - // (signed)x % 1 ==> 0 - if (binary->op == Abstract::getBinary(type, Abstract::RemS) && - !EffectAnalyzer(getPassOptions(), features, binary->left) - .hasSideEffects()) { - right->value = Literal::makeSingleZero(type); - return right; - } - // bool(x) | 1 ==> 1 - // bool(x) & 1 ==> bool(x) - // bool(x) == 1 ==> bool(x) - // bool(x) != 1 ==> !bool(x) - if (Bits::getMaxBits(binary->left, this) == 1) { - switch (binary->op) { - case OrInt32: - case OrInt64: { - if (!EffectAnalyzer(getPassOptions(), features, binary->left) - .hasSideEffects()) { - // bool(x) | 1 ==> 1 - return binary->right; - } - break; - } - case AndInt32: - case AndInt64: - case EqInt32: { - // bool(x) & 1 ==> bool(x) - // bool(x) == 1 ==> bool(x) - return binary->left; - } - case EqInt64: { - // i64(bool(x)) == 1 ==> i32(bool(x)) - return Builder(*getModule()).makeUnary(WrapInt64, binary->left); - } - case NeInt32: - case NeInt64: { - // bool(x) != 1 ==> !bool(x) - return Builder(*getModule()) - .makeUnary( - Abstract::getUnary(binary->left->type, Abstract::EqZ), - binary->left); - } - default: { - } - } - } - } - // operations on all 1s - if (constRight == -1LL) { - if (binary->op == Abstract::getBinary(type, Abstract::And)) { - // x & -1 ==> x - return binary->left; - } else if (binary->op == Abstract::getBinary(type, Abstract::Or) && - !EffectAnalyzer(getPassOptions(), features, binary->left) - .hasSideEffects()) { - // x | -1 ==> -1 - return binary->right; - } else if (binary->op == Abstract::getBinary(type, Abstract::RemS) && - !EffectAnalyzer(getPassOptions(), features, binary->left) - .hasSideEffects()) { - // (signed)x % -1 ==> 0 - right->value = Literal::makeSingleZero(type); - return right; - } else if (binary->op == Abstract::getBinary(type, Abstract::GtU) && - !EffectAnalyzer(getPassOptions(), features, binary->left) - .hasSideEffects()) { - // (unsigned)x > -1 ==> 0 - right->value = Literal::makeSingleZero(Type::i32); - right->type = Type::i32; - return right; - } else if (binary->op == Abstract::getBinary(type, Abstract::LtU)) { - // (unsigned)x < -1 ==> x != -1 - // friendlier to JS emitting as we don't need to write an unsigned - // -1 value which is large. - binary->op = Abstract::getBinary(type, Abstract::Ne); - return binary; - } else if (binary->op == DivUInt32) { - // (unsigned)x / -1 ==> x == -1 - binary->op = Abstract::getBinary(type, Abstract::Eq); - return binary; - } else if (binary->op == Abstract::getBinary(type, Abstract::Mul)) { - // x * -1 ==> 0 - x - binary->op = Abstract::getBinary(type, Abstract::Sub); - right->value = Literal::makeSingleZero(type); - std::swap(binary->left, binary->right); - return binary; - } else if (binary->op == Abstract::getBinary(type, Abstract::LeU) && - !EffectAnalyzer(getPassOptions(), features, binary->left) - .hasSideEffects()) { - // (unsigned)x <= -1 ==> 1 - right->value = Literal::makeFromInt32(1, Type::i32); - right->type = Type::i32; - return right; - } - } - // wasm binary encoding uses signed LEBs, which slightly favor negative + Expression* optimizeWithConstantOnRight(Binary* curr) { + using namespace Match; + Builder builder(*getModule()); + Expression* left; + auto* right = curr->right->cast<Const>(); + auto type = curr->right->type; + + // Operations on zero + if (matches(curr, binary(Abstract::Shl, any(&left), ival(0))) || + matches(curr, binary(Abstract::ShrU, any(&left), ival(0))) || + matches(curr, binary(Abstract::ShrS, any(&left), ival(0))) || + matches(curr, binary(Abstract::Or, any(&left), ival(0))) || + matches(curr, binary(Abstract::Xor, any(&left), ival(0)))) { + return left; + } + if (matches(curr, binary(Abstract::Mul, pure(&left), ival(0))) || + matches(curr, binary(Abstract::And, pure(&left), ival(0)))) { + return right; + } + // x == 0 ==> eqz x + if ((matches(curr, binary(Abstract::Eq, any(&left), ival(0))))) { + return builder.makeUnary(EqZInt64, left); + } + + // Operations on one + // (signed)x % 1 ==> 0 + if (matches(curr, binary(Abstract::RemS, pure(&left), ival(1)))) { + right->value = Literal::makeSingleZero(type); + return right; + } + // bool(x) | 1 ==> 1 + if (matches(curr, binary(Abstract::Or, pure(&left), ival(1))) && + Bits::getMaxBits(left, this) == 1) { + return right; + } + // bool(x) & 1 ==> bool(x) + if (matches(curr, binary(Abstract::And, any(&left), ival(1))) && + Bits::getMaxBits(left, this) == 1) { + return left; + } + // bool(x) == 1 ==> bool(x) + if (matches(curr, binary(EqInt32, any(&left), i32(1))) && + Bits::getMaxBits(left, this) == 1) { + return left; + } + // i64(bool(x)) == 1 ==> i32(bool(x)) + if (matches(curr, binary(EqInt64, any(&left), i64(1))) && + Bits::getMaxBits(left, this) == 1) { + return builder.makeUnary(WrapInt64, left); + } + // bool(x) != 1 ==> !bool(x) + if (matches(curr, binary(Abstract::Ne, any(&left), ival(1))) && + Bits::getMaxBits(curr->left, this) == 1) { + return builder.makeUnary(Abstract::getUnary(type, Abstract::EqZ), left); + } + + // Operations on all 1s + // x & -1 ==> x + if (matches(curr, binary(Abstract::And, any(&left), ival(-1)))) { + return left; + } + // x | -1 ==> -1 + if (matches(curr, binary(Abstract::Or, pure(&left), ival(-1)))) { + return right; + } + // (signed)x % -1 ==> 0 + if (matches(curr, binary(Abstract::RemS, pure(&left), ival(-1)))) { + right->value = Literal::makeSingleZero(type); + return right; + } + // (unsigned)x > -1 ==> 0 + if (matches(curr, binary(Abstract::GtU, pure(&left), ival(-1)))) { + right->value = Literal::makeSingleZero(Type::i32); + right->type = Type::i32; + return right; + } + // (unsigned)x < -1 ==> x != -1 + // Friendlier to JS emitting as we don't need to write an unsigned -1 value + // which is large. + if (matches(curr, binary(Abstract::LtU, any(), ival(-1)))) { + curr->op = Abstract::getBinary(type, Abstract::Ne); + return curr; + } + // (unsigned)x / -1 ==> x == -1 + // TODO: i64 as well if sign-extension is enabled + if (matches(curr, binary(DivUInt32, any(), ival(-1)))) { + curr->op = Abstract::getBinary(type, Abstract::Eq); + return curr; + } + // x * -1 ==> 0 - x + if (matches(curr, binary(Abstract::Mul, any(&left), ival(-1)))) { + right->value = Literal::makeSingleZero(type); + curr->op = Abstract::getBinary(type, Abstract::Sub); + curr->left = right; + curr->right = left; + return curr; + } + // (unsigned)x <= -1 ==> 1 + if (matches(curr, binary(Abstract::LeU, pure(&left), ival(-1)))) { + right->value = Literal::makeFromInt32(1, Type::i32); + right->type = Type::i32; + return right; + } + { + // Wasm binary encoding uses signed LEBs, which slightly favor negative // numbers: -64 is more efficient than +64 etc., as well as other powers - // of two 7 bits etc. higher. we therefore prefer x - -64 over x + 64. - // in theory we could just prefer negative numbers over positive, but - // that can have bad effects on gzip compression (as it would mean more - // subtractions than the more common additions). - if (binary->op == Abstract::getBinary(type, Abstract::Add) || - binary->op == Abstract::getBinary(type, Abstract::Sub)) { - auto value = constRight; - if (value == 0x40 || value == 0x2000 || value == 0x100000 || - value == 0x8000000 || value == 0x400000000LL || - value == 0x20000000000LL || value == 0x1000000000000LL || - value == 0x80000000000000LL || value == 0x4000000000000000LL) { - right->value = right->value.neg(); - if (binary->op == Abstract::getBinary(type, Abstract::Add)) { - binary->op = Abstract::getBinary(type, Abstract::Sub); - } else { - binary->op = Abstract::getBinary(type, Abstract::Add); - } - return binary; + // of two 7 bits etc. higher. we therefore prefer x - -64 over x + 64. in + // theory we could just prefer negative numbers over positive, but that + // can have bad effects on gzip compression (as it would mean more + // subtractions than the more common additions). TODO: Simplify this by + // adding an ival matcher than can bind int64_t vars. + int64_t value; + if ((matches(curr, binary(Abstract::Add, any(), ival(&value))) || + matches(curr, binary(Abstract::Sub, any(), ival(&value)))) && + (value == 0x40 || value == 0x2000 || value == 0x100000 || + value == 0x8000000 || value == 0x400000000LL || + value == 0x20000000000LL || value == 0x1000000000000LL || + value == 0x80000000000000LL || value == 0x4000000000000000LL)) { + right->value = right->value.neg(); + if (matches(curr, binary(Abstract::Add, any(), constant()))) { + curr->op = Abstract::getBinary(type, Abstract::Sub); + } else { + curr->op = Abstract::getBinary(type, Abstract::Add); } + return curr; } } - if (type.isFloat()) { - auto value = right->value.getFloat(); - if (value == 0.0) { - if (binary->op == Abstract::getBinary(type, Abstract::Sub)) { - if (std::signbit(value)) { - // x - (-0.0) ==> x + 0.0 - binary->op = Abstract::getBinary(type, Abstract::Add); - right->value = right->value.neg(); - return binary; - } else { - // x - 0.0 ==> x - return binary->left; - } - } else if (binary->op == Abstract::getBinary(type, Abstract::Add)) { - if (std::signbit(value)) { - // x + (-0.0) ==> x - return binary->left; - } + { + double value; + if (matches(curr, binary(Abstract::Sub, any(), fval(&value))) && + value == 0.0) { + // x - (-0.0) ==> x + 0.0 + if (std::signbit(value)) { + curr->op = Abstract::getBinary(type, Abstract::Add); + right->value = right->value.neg(); + return curr; + } else { + // x - 0.0 ==> x + return curr->left; } } } - if (type.isInteger() || type.isFloat()) { - // note that this is correct even on floats with a NaN on the left, - // as a NaN would skip the computation and just return the NaN, - // and that is precisely what we do here. but, the same with -1 - // (change to a negation) would be incorrect for that reason. - if (right->value == Literal::makeFromInt32(1, type)) { - if (binary->op == Abstract::getBinary(type, Abstract::Mul) || - binary->op == Abstract::getBinary(type, Abstract::DivS) || - binary->op == Abstract::getBinary(type, Abstract::DivU)) { - return binary->left; - } + { + // x + (-0.0) ==> x + double value; + if (matches(curr, binary(Abstract::Add, any(), fval(&value))) && + value == 0.0 && std::signbit(value)) { + return curr->left; } } - // TODO: v128 not implemented yet + // Note that this is correct even on floats with a NaN on the left, + // as a NaN would skip the computation and just return the NaN, + // and that is precisely what we do here. but, the same with -1 + // (change to a negation) would be incorrect for that reason. + if (matches(curr, binary(Abstract::Mul, any(&left), constant(1))) || + matches(curr, binary(Abstract::DivS, any(&left), constant(1))) || + matches(curr, binary(Abstract::DivU, any(&left), constant(1)))) { + return left; + } return nullptr; } @@ -1397,9 +1415,7 @@ private: if ((binary->op == Abstract::getBinary(type, Abstract::Shl) || binary->op == Abstract::getBinary(type, Abstract::ShrU) || binary->op == Abstract::getBinary(type, Abstract::ShrS)) && - !EffectAnalyzer( - getPassOptions(), getModule()->features, binary->right) - .hasSideEffects()) { + !effects(binary->right).hasSideEffects()) { return binary->left; } } diff --git a/test/example/match.cpp b/test/example/match.cpp new file mode 100644 index 000000000..e485f2355 --- /dev/null +++ b/test/example/match.cpp @@ -0,0 +1,448 @@ +#include <cassert> +#include <iostream> + +#include "literal.h" +#include "wasm-builder.h" +#include <ir/match.h> + +using namespace wasm; +using namespace wasm::Match; + +Module module; +Builder builder(module); + +void test_internal_any() { + std::cout << "Testing Internal::Any\n"; + + assert(Internal::Any<int32_t>(nullptr).matches(0)); + assert(Internal::Any<int32_t>(nullptr).matches(1)); + assert(Internal::Any<int32_t>(nullptr).matches(-1)); + assert(Internal::Any<int32_t>(nullptr).matches(42LL)); + assert(Internal::Any<int32_t>(nullptr).matches(4.2f)); + + assert(Internal::Any<int64_t>(nullptr).matches(0)); + assert(Internal::Any<int64_t>(nullptr).matches(1)); + assert(Internal::Any<int64_t>(nullptr).matches(-1)); + assert(Internal::Any<int64_t>(nullptr).matches(42LL)); + assert(Internal::Any<int64_t>(nullptr).matches(4.2f)); + + assert(Internal::Any<float>(nullptr).matches(0)); + assert(Internal::Any<float>(nullptr).matches(1)); + assert(Internal::Any<float>(nullptr).matches(-1)); + assert(Internal::Any<float>(nullptr).matches(42LL)); + assert(Internal::Any<float>(nullptr).matches(4.2f)); + + assert(Internal::Any<double>(nullptr).matches(0)); + assert(Internal::Any<double>(nullptr).matches(1)); + assert(Internal::Any<double>(nullptr).matches(-1)); + assert(Internal::Any<double>(nullptr).matches(42LL)); + assert(Internal::Any<double>(nullptr).matches(4.2f)); + + // Working as intended: cannot convert `const char [6]' to double + // assert(Internal::Any<double>(nullptr).matches("hello")); + + { + int32_t val = 0xffffffff; + assert(Internal::Any<int32_t>(&val).matches(0)); + assert(val == 0); + assert(Internal::Any<int32_t>(&val).matches(1)); + assert(val == 1); + assert(Internal::Any<int32_t>(&val).matches(-1)); + assert(val == -1); + assert(Internal::Any<int32_t>(&val).matches(42LL)); + assert(val == 42); + assert(Internal::Any<int32_t>(&val).matches(4.2f)); + assert(val == 4); + } + + { + Expression* expr = nullptr; + Nop* nop = nullptr; + + Expression* builtExpr = builder.makeNop(); + Nop* builtNop = builder.makeNop(); + AtomicFence* builtFence = builder.makeAtomicFence(); + + assert(Internal::Any(&expr).matches(builtExpr)); + assert(expr == builtExpr); + + assert(Internal::Any(&expr).matches(builtNop)); + assert(expr == builtNop); + + assert(Internal::Any(&expr).matches(builtFence)); + assert(expr == builtFence); + + assert(Internal::Any(&nop).matches(builtExpr)); + assert(nop == builtExpr); + + assert(Internal::Any(&nop).matches(builtNop)); + assert(nop == builtNop); + + // Does NOT match sibling expression types. Bound variable unchanged. + assert(!Internal::Any(&nop).matches(builtFence)); + assert(nop == builtNop); + + // Working as intended: invalid conversion from Expression** to Nop** + // assert(Internal::Any<Nop*>(&expr).matches(builtExpr)); + } +} + +void test_internal_exact() { + std::cout << "Testing Internal::Exact\n"; + + assert(Internal::Exact<int32_t>(nullptr, 0).matches(0)); + assert(Internal::Exact<int32_t>(nullptr, 1).matches(1)); + assert(Internal::Exact<int32_t>(nullptr, -1).matches(-1)); + assert(Internal::Exact<int32_t>(nullptr, 42).matches(42LL)); + assert(Internal::Exact<int32_t>(nullptr, 4).matches(4.2f)); + + assert(!Internal::Exact<int32_t>(nullptr, 1).matches(0)); + assert(!Internal::Exact<int32_t>(nullptr, -1).matches(1)); + assert(!Internal::Exact<int32_t>(nullptr, 42).matches(-1)); + assert(!Internal::Exact<int32_t>(nullptr, 4).matches(42LL)); + assert(!Internal::Exact<int32_t>(nullptr, 0).matches(4.2f)); + + { + Expression* expr = nullptr; + Nop* nop = nullptr; + + Nop* builtNop = builder.makeNop(); + Expression* builtExpr = builtNop; + + assert(!Internal::Exact(&expr, expr).matches(builtExpr)); + assert(Internal::Exact(&expr, builtExpr).matches(builtExpr)); + assert(expr == builtExpr); + + assert(!Internal::Exact(&nop, nop).matches(builtNop)); + assert(Internal::Exact(&nop, builtNop).matches(builtNop)); + assert(nop == builtNop); + nop = nullptr; + assert(Internal::Exact(&nop, builtNop).matches(builtExpr)); + assert(nop == builtNop); + } +} + +void test_internal_literal() { + std::cout << "Testing Internal::{I32,I64,Int,F32,F64,Float}Lit\n"; + + Literal i32Zero(int32_t(0)); + Literal i32One(int32_t(1)); + Literal f32Zero(float(0)); + Literal f32One(float(1)); + Literal i64Zero(int64_t(0)); + Literal i64One(int64_t(1)); + Literal f64Zero(double(0)); + Literal f64One(double(1)); + + auto anyi32 = Internal::I32Lit(nullptr, Internal::Any<int32_t>(nullptr)); + assert(anyi32.matches(i32Zero)); + assert(anyi32.matches(i32One)); + assert(!anyi32.matches(f32Zero)); + assert(!anyi32.matches(f32One)); + assert(!anyi32.matches(i64Zero)); + assert(!anyi32.matches(i64One)); + assert(!anyi32.matches(f64Zero)); + assert(!anyi32.matches(f64One)); + + auto onei32 = Internal::I32Lit(nullptr, Internal::Exact<int32_t>(nullptr, 1)); + assert(!onei32.matches(i32Zero)); + assert(onei32.matches(i32One)); + assert(!onei32.matches(f32Zero)); + assert(!onei32.matches(f32One)); + assert(!onei32.matches(i64Zero)); + assert(!onei32.matches(i64One)); + assert(!onei32.matches(f64Zero)); + assert(!onei32.matches(f64One)); + + auto anyi64 = Internal::I64Lit(nullptr, Internal::Any<int64_t>(nullptr)); + assert(!anyi64.matches(i32Zero)); + assert(!anyi64.matches(i32One)); + assert(!anyi64.matches(f32Zero)); + assert(!anyi64.matches(f32One)); + assert(anyi64.matches(i64Zero)); + assert(anyi64.matches(i64One)); + assert(!anyi64.matches(f64Zero)); + assert(!anyi64.matches(f64One)); + + auto onei64 = Internal::I64Lit(nullptr, Internal::Exact<int64_t>(nullptr, 1)); + assert(!onei64.matches(i32Zero)); + assert(!onei64.matches(i32One)); + assert(!onei64.matches(f32Zero)); + assert(!onei64.matches(f32One)); + assert(!onei64.matches(i64Zero)); + assert(onei64.matches(i64One)); + assert(!onei64.matches(f64Zero)); + assert(!onei64.matches(f64One)); + + auto anyint = Internal::IntLit(nullptr, Internal::Any<int64_t>(nullptr)); + assert(anyint.matches(i32Zero)); + assert(anyint.matches(i32One)); + assert(!anyint.matches(f32Zero)); + assert(!anyint.matches(f32One)); + assert(anyint.matches(i64Zero)); + assert(anyint.matches(i64One)); + assert(!anyint.matches(f64Zero)); + assert(!anyint.matches(f64One)); + + auto oneint = Internal::IntLit(nullptr, Internal::Exact<int64_t>(nullptr, 1)); + assert(!oneint.matches(i32Zero)); + assert(oneint.matches(i32One)); + assert(!oneint.matches(f32Zero)); + assert(!oneint.matches(f32One)); + assert(!oneint.matches(i64Zero)); + assert(oneint.matches(i64One)); + assert(!oneint.matches(f64Zero)); + assert(!oneint.matches(f64One)); + + auto anyf32 = Internal::F32Lit(nullptr, Internal::Any<float>(nullptr)); + assert(!anyf32.matches(i32Zero)); + assert(!anyf32.matches(i32One)); + assert(anyf32.matches(f32Zero)); + assert(anyf32.matches(f32One)); + assert(!anyf32.matches(i64Zero)); + assert(!anyf32.matches(i64One)); + assert(!anyf32.matches(f64Zero)); + assert(!anyf32.matches(f64One)); + + auto onef32 = Internal::F32Lit(nullptr, Internal::Exact<float>(nullptr, 1)); + assert(!onef32.matches(i32Zero)); + assert(!onef32.matches(i32One)); + assert(!onef32.matches(f32Zero)); + assert(onef32.matches(f32One)); + assert(!onef32.matches(i64Zero)); + assert(!onef32.matches(i64One)); + assert(!onef32.matches(f64Zero)); + assert(!onef32.matches(f64One)); + + auto anyf64 = Internal::F64Lit(nullptr, Internal::Any<double>(nullptr)); + assert(!anyf64.matches(i32Zero)); + assert(!anyf64.matches(i32One)); + assert(!anyf64.matches(f32Zero)); + assert(!anyf64.matches(f32One)); + assert(!anyf64.matches(i64Zero)); + assert(!anyf64.matches(i64One)); + assert(anyf64.matches(f64Zero)); + assert(anyf64.matches(f64One)); + + auto onef64 = Internal::F64Lit(nullptr, Internal::Exact<double>(nullptr, 1)); + assert(!onef64.matches(i32Zero)); + assert(!onef64.matches(i32One)); + assert(!onef64.matches(f32Zero)); + assert(!onef64.matches(f32One)); + assert(!onef64.matches(i64Zero)); + assert(!onef64.matches(i64One)); + assert(!onef64.matches(f64Zero)); + assert(onef64.matches(f64One)); + + auto anyfp = Internal::FloatLit(nullptr, Internal::Any<double>(nullptr)); + assert(!anyfp.matches(i32Zero)); + assert(!anyfp.matches(i32One)); + assert(anyfp.matches(f32Zero)); + assert(anyfp.matches(f32One)); + assert(!anyfp.matches(i64Zero)); + assert(!anyfp.matches(i64One)); + assert(anyfp.matches(f64Zero)); + assert(anyfp.matches(f64One)); + + auto onefp = Internal::FloatLit(nullptr, Internal::Exact<double>(nullptr, 1)); + assert(!onefp.matches(i32Zero)); + assert(!onefp.matches(i32One)); + assert(!onefp.matches(f32Zero)); + assert(onefp.matches(f32One)); + assert(!onefp.matches(i64Zero)); + assert(!onefp.matches(i64One)); + assert(!onefp.matches(f64Zero)); + assert(onefp.matches(f64One)); + + auto number = Internal::NumberLit(nullptr, 1); + assert(!number.matches(i32Zero)); + assert(number.matches(i32One)); + assert(!number.matches(f32Zero)); + assert(number.matches(f32One)); + assert(!number.matches(i64Zero)); + assert(number.matches(i64One)); + assert(!number.matches(f64Zero)); + assert(number.matches(f64One)); + + int64_t x = 0; + Literal xLit; + Literal imatched(int32_t(42)); + assert(Internal::IntLit(&xLit, Internal::Any(&x)).matches(imatched)); + assert(xLit == imatched); + assert(x == 42); + + double f = 0; + Literal fLit; + Literal fmatched(double(42)); + assert(Internal::FloatLit(&fLit, Internal::Any(&f)).matches(fmatched)); + assert(fLit == fmatched); + assert(f == 42.0); + + Literal numLit; + Literal numMatched(1.0f); + assert(Internal::NumberLit(&numLit, 1).matches(numMatched)); + assert(numLit == numMatched); +} + +void test_internal_const() { + std::cout << "Testing Internal::ConstantMatcher\n"; + + Const* c = builder.makeConst(Literal(int32_t(42))); + Expression* constExpr = builder.makeConst(Literal(int32_t(43))); + Expression* nop = builder.makeNop(); + + Const* extractedConst = nullptr; + Literal extractedLit; + int32_t extractedInt = 0; + + auto matcher = Internal::ConstMatcher( + &extractedConst, + Internal::I32Lit(&extractedLit, Internal::Any(&extractedInt))); + + assert(matcher.matches(c)); + assert(extractedConst == c); + assert(extractedLit == Literal(int32_t(42))); + assert(extractedInt == 42); + + assert(matcher.matches(constExpr)); + assert(extractedConst == constExpr); + assert(extractedLit == Literal(int32_t(43))); + assert(extractedInt == 43); + + assert(!matcher.matches(nop)); +} + +void test_internal_unary() { + Expression* eqz32 = + builder.makeUnary(EqZInt32, builder.makeConst(Literal(int32_t(0)))); + Expression* eqz64 = + builder.makeUnary(EqZInt64, builder.makeConst(Literal(int64_t(0)))); + Expression* clz = + builder.makeUnary(ClzInt32, builder.makeConst(Literal(int32_t(0)))); + Expression* nop = builder.makeNop(); + + std::cout << "Testing Internal::UnaryMatcher\n"; + + Unary* out = nullptr; + + auto eqz32Matcher = + Internal::UnaryMatcher(&out, EqZInt32, Internal::Any<Expression*>(nullptr)); + assert(eqz32Matcher.matches(eqz32)); + assert(out == eqz32); + assert(!eqz32Matcher.matches(eqz64)); + assert(!eqz32Matcher.matches(clz)); + assert(!eqz32Matcher.matches(nop)); + + std::cout << "Testing Internal::AbstractUnaryMatcher\n"; + + out = nullptr; + + auto eqzMatcher = Internal::AbstractUnaryMatcher( + &out, Abstract::EqZ, Internal::Any<Expression*>(nullptr)); + assert(eqzMatcher.matches(eqz32)); + assert(out == eqz32); + assert(eqzMatcher.matches(eqz64)); + assert(out == eqz64); + assert(!eqzMatcher.matches(clz)); + assert(!eqzMatcher.matches(nop)); +} + +void test_internal_binary() { + Expression* eq32 = builder.makeBinary(EqInt32, + builder.makeConst(Literal(int32_t(0))), + builder.makeConst(Literal(int32_t(0)))); + Expression* eq64 = builder.makeBinary(EqInt64, + builder.makeConst(Literal(int64_t(0))), + builder.makeConst(Literal(int64_t(0)))); + Expression* add = builder.makeBinary(AddInt32, + builder.makeConst(Literal(int32_t(0))), + builder.makeConst(Literal(int32_t(0)))); + Expression* nop = builder.makeNop(); + + std::cout << "Testing Internal::BinaryMatcher\n"; + + Binary* out = nullptr; + + auto eq32Matcher = + Internal::BinaryMatcher(&out, + EqInt32, + Internal::Any<Expression*>(nullptr), + Internal::Any<Expression*>(nullptr)); + assert(eq32Matcher.matches(eq32)); + assert(out == eq32); + assert(!eq32Matcher.matches(eq64)); + assert(!eq32Matcher.matches(add)); + assert(!eq32Matcher.matches(nop)); + + std::cout << "Testing Internal::AbstractBinaryMatcher\n"; + + out = nullptr; + + auto eqMatcher = + Internal::AbstractBinaryMatcher(&out, + Abstract::Eq, + Internal::Any<Expression*>(nullptr), + Internal::Any<Expression*>(nullptr)); + assert(eqMatcher.matches(eq32)); + assert(out == eq32); + assert(eqMatcher.matches(eq64)); + assert(out == eq64); + assert(!eqMatcher.matches(add)); + assert(!eqMatcher.matches(nop)); +} + +void test_internal_select() { + std::cout << "Testing Internal::SelectMatcher\n"; + + auto zero = [&]() { return builder.makeConst(Literal(int32_t(0))); }; + auto one = [&]() { return builder.makeConst(Literal(int32_t(1))); }; + + auto constMatcher = [](int32_t c) { + return Internal::ConstMatcher( + nullptr, Internal::I32Lit(nullptr, Internal::Exact<int32_t>(nullptr, c))); + }; + + // NB: `makeSelect` takes the condition first for some reason + Expression* leftOne = builder.makeSelect(zero(), one(), zero()); + Expression* rightOne = builder.makeSelect(zero(), zero(), one()); + Expression* condOne = builder.makeSelect(one(), zero(), zero()); + + Select* out = nullptr; + + auto zeroesMatcher = Internal::SelectMatcher( + &out, constMatcher(0), constMatcher(0), constMatcher(0)); + assert(!zeroesMatcher.matches(leftOne)); + assert(!zeroesMatcher.matches(rightOne)); + assert(!zeroesMatcher.matches(condOne)); + + auto leftMatcher = Internal::SelectMatcher( + &out, constMatcher(1), constMatcher(0), constMatcher(0)); + assert(leftMatcher.matches(leftOne)); + assert(out == leftOne); + assert(!leftMatcher.matches(rightOne)); + assert(!leftMatcher.matches(condOne)); + + auto rightMatcher = Internal::SelectMatcher( + &out, constMatcher(0), constMatcher(1), constMatcher(0)); + assert(!rightMatcher.matches(leftOne)); + assert(rightMatcher.matches(rightOne)); + assert(out == rightOne); + assert(!rightMatcher.matches(condOne)); + + auto condMatcher = Internal::SelectMatcher( + &out, constMatcher(0), constMatcher(0), constMatcher(1)); + assert(!condMatcher.matches(leftOne)); + assert(!condMatcher.matches(rightOne)); + assert(condMatcher.matches(condOne)); + assert(out == condOne); +} + +int main() { + test_internal_any(); + test_internal_exact(); + test_internal_literal(); + test_internal_const(); + test_internal_unary(); + test_internal_binary(); + test_internal_select(); +} diff --git a/test/example/match.txt b/test/example/match.txt new file mode 100644 index 000000000..697d9e9f0 --- /dev/null +++ b/test/example/match.txt @@ -0,0 +1,9 @@ +Testing Internal::Any +Testing Internal::Exact +Testing Internal::{I32,I64,Int,F32,F64,Float}Lit +Testing Internal::ConstantMatcher +Testing Internal::UnaryMatcher +Testing Internal::AbstractUnaryMatcher +Testing Internal::BinaryMatcher +Testing Internal::AbstractBinaryMatcher +Testing Internal::SelectMatcher |