diff options
-rw-r--r-- | CMakeLists.txt | 15 | ||||
-rw-r--r-- | src/ast/CMakeLists.txt | 5 | ||||
-rw-r--r-- | src/ast/ExpressionAnalyzer.cpp | 495 | ||||
-rw-r--r-- | src/ast/ExpressionManipulator.cpp | 152 | ||||
-rw-r--r-- | src/ast_utils.h | 616 | ||||
-rw-r--r-- | src/passes/DuplicateFunctionElimination.cpp | 8 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 92 |
7 files changed, 725 insertions, 658 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 74471d1e5..e557b289c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -118,6 +118,7 @@ IF (UNIX AND ENDIF() # Static libraries +ADD_SUBDIRECTORY(src/ast) ADD_SUBDIRECTORY(src/asmjs) ADD_SUBDIRECTORY(src/emscripten-optimizer) ADD_SUBDIRECTORY(src/passes) @@ -136,7 +137,7 @@ IF(BUILD_STATIC_LIB) ELSE() ADD_LIBRARY(binaryen SHARED ${binaryen_SOURCES}) ENDIF() -TARGET_LINK_LIBRARIES(binaryen ${all_passes} wasm asmjs support) +TARGET_LINK_LIBRARIES(binaryen ${all_passes} wasm asmjs ast support) INSTALL(TARGETS binaryen DESTINATION lib) INSTALL(FILES src/binaryen-c.h DESTINATION include) @@ -150,7 +151,7 @@ SET(wasm-shell_SOURCES ) ADD_EXECUTABLE(wasm-shell ${wasm-shell_SOURCES}) -TARGET_LINK_LIBRARIES(wasm-shell wasm asmjs emscripten-optimizer ${all_passes} support) +TARGET_LINK_LIBRARIES(wasm-shell wasm asmjs emscripten-optimizer ${all_passes} ast support) SET_PROPERTY(TARGET wasm-shell PROPERTY CXX_STANDARD 11) SET_PROPERTY(TARGET wasm-shell PROPERTY CXX_STANDARD_REQUIRED ON) INSTALL(TARGETS wasm-shell DESTINATION bin) @@ -160,7 +161,7 @@ SET(wasm-opt_SOURCES ) ADD_EXECUTABLE(wasm-opt ${wasm-opt_SOURCES}) -TARGET_LINK_LIBRARIES(wasm-opt wasm asmjs emscripten-optimizer ${all_passes} support) +TARGET_LINK_LIBRARIES(wasm-opt wasm asmjs emscripten-optimizer ${all_passes} ast support) SET_PROPERTY(TARGET wasm-opt PROPERTY CXX_STANDARD 11) SET_PROPERTY(TARGET wasm-opt PROPERTY CXX_STANDARD_REQUIRED ON) INSTALL(TARGETS wasm-opt DESTINATION bin) @@ -171,7 +172,7 @@ SET(asm2wasm_SOURCES ) ADD_EXECUTABLE(asm2wasm ${asm2wasm_SOURCES}) -TARGET_LINK_LIBRARIES(asm2wasm emscripten-optimizer ${all_passes} wasm asmjs support) +TARGET_LINK_LIBRARIES(asm2wasm emscripten-optimizer ${all_passes} wasm asmjs ast support) SET_PROPERTY(TARGET asm2wasm PROPERTY CXX_STANDARD 11) SET_PROPERTY(TARGET asm2wasm PROPERTY CXX_STANDARD_REQUIRED ON) INSTALL(TARGETS asm2wasm DESTINATION bin) @@ -183,7 +184,7 @@ SET(s2wasm_SOURCES ) ADD_EXECUTABLE(s2wasm ${s2wasm_SOURCES}) -TARGET_LINK_LIBRARIES(s2wasm passes wasm asmjs support) +TARGET_LINK_LIBRARIES(s2wasm passes wasm asmjs ast support) SET_PROPERTY(TARGET s2wasm PROPERTY CXX_STANDARD 11) SET_PROPERTY(TARGET s2wasm PROPERTY CXX_STANDARD_REQUIRED ON) INSTALL(TARGETS s2wasm DESTINATION bin) @@ -193,7 +194,7 @@ SET(wasm_as_SOURCES ) ADD_EXECUTABLE(wasm-as ${wasm_as_SOURCES}) -TARGET_LINK_LIBRARIES(wasm-as wasm asmjs passes support) +TARGET_LINK_LIBRARIES(wasm-as wasm asmjs passes ast support) SET_PROPERTY(TARGET wasm-as PROPERTY CXX_STANDARD 11) SET_PROPERTY(TARGET wasm-as PROPERTY CXX_STANDARD_REQUIRED ON) INSTALL(TARGETS wasm-as DESTINATION bin) @@ -203,7 +204,7 @@ SET(wasm_dis_SOURCES ) ADD_EXECUTABLE(wasm-dis ${wasm_dis_SOURCES}) -TARGET_LINK_LIBRARIES(wasm-dis passes wasm asmjs support) +TARGET_LINK_LIBRARIES(wasm-dis passes wasm asmjs ast support) SET_PROPERTY(TARGET wasm-dis PROPERTY CXX_STANDARD 11) SET_PROPERTY(TARGET wasm-dis PROPERTY CXX_STANDARD_REQUIRED ON) INSTALL(TARGETS wasm-dis DESTINATION bin) diff --git a/src/ast/CMakeLists.txt b/src/ast/CMakeLists.txt new file mode 100644 index 000000000..e48e84eed --- /dev/null +++ b/src/ast/CMakeLists.txt @@ -0,0 +1,5 @@ +SET(ast_SOURCES + ExpressionAnalyzer.cpp + ExpressionManipulator.cpp +) +ADD_LIBRARY(ast STATIC ${ast_SOURCES}) diff --git a/src/ast/ExpressionAnalyzer.cpp b/src/ast/ExpressionAnalyzer.cpp new file mode 100644 index 000000000..421206706 --- /dev/null +++ b/src/ast/ExpressionAnalyzer.cpp @@ -0,0 +1,495 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ast_utils.h" +#include "support/hash.h" + + +namespace wasm { +// Given a stack of expressions, checks if the topmost is used as a result. +// For example, if the parent is a block and the node is before the last position, +// it is not used. +bool ExpressionAnalyzer::isResultUsed(std::vector<Expression*> stack, Function* func) { + for (int i = int(stack.size()) - 2; i >= 0; i--) { + auto* curr = stack[i]; + auto* above = stack[i + 1]; + // only if and block can drop values (pre-drop expression was added) FIXME + if (curr->is<Block>()) { + auto* block = curr->cast<Block>(); + for (size_t j = 0; j < block->list.size() - 1; j++) { + if (block->list[j] == above) return false; + } + assert(block->list.back() == above); + // continue down + } else if (curr->is<If>()) { + auto* iff = curr->cast<If>(); + if (above == iff->condition) return true; + if (!iff->ifFalse) return false; + assert(above == iff->ifTrue || above == iff->ifFalse); + // continue down + } else { + if (curr->is<Drop>()) return false; + return true; // all other node types use the result + } + } + // The value might be used, so it depends on if the function returns + return func->result != none; +} + +// Checks if a value is dropped. +bool ExpressionAnalyzer::isResultDropped(std::vector<Expression*> stack) { + for (int i = int(stack.size()) - 2; i >= 0; i--) { + auto* curr = stack[i]; + auto* above = stack[i + 1]; + if (curr->is<Block>()) { + auto* block = curr->cast<Block>(); + for (size_t j = 0; j < block->list.size() - 1; j++) { + if (block->list[j] == above) return false; + } + assert(block->list.back() == above); + // continue down + } else if (curr->is<If>()) { + auto* iff = curr->cast<If>(); + if (above == iff->condition) return false; + if (!iff->ifFalse) return false; + assert(above == iff->ifTrue || above == iff->ifFalse); + // continue down + } else { + if (curr->is<Drop>()) return true; // dropped + return false; // all other node types use the result + } + } + return false; +} + + +bool ExpressionAnalyzer::flexibleEqual(Expression* left, Expression* right, ExprComparer comparer) { + std::vector<Name> nameStack; + std::map<Name, std::vector<Name>> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively + Nop popNameMarker; + std::vector<Expression*> leftStack; + std::vector<Expression*> rightStack; + + auto noteNames = [&](Name left, Name right) { + if (left.is() != right.is()) return false; + if (left.is()) { + nameStack.push_back(left); + rightNames[left].push_back(right); + leftStack.push_back(&popNameMarker); + rightStack.push_back(&popNameMarker); + } + return true; + }; + auto checkNames = [&](Name left, Name right) { + auto iter = rightNames.find(left); + if (iter == rightNames.end()) return left == right; // non-internal name + return iter->second.back() == right; + }; + auto popName = [&]() { + auto left = nameStack.back(); + nameStack.pop_back(); + rightNames[left].pop_back(); + }; + + leftStack.push_back(left); + rightStack.push_back(right); + + while (leftStack.size() > 0 && rightStack.size() > 0) { + left = leftStack.back(); + leftStack.pop_back(); + right = rightStack.back(); + rightStack.pop_back(); + if (!left != !right) return false; + if (!left) continue; + if (left == &popNameMarker) { + popName(); + continue; + } + if (comparer(left, right)) continue; // comparison hook, before all the rest + // continue with normal structural comparison + if (left->_id != right->_id) return false; + #define PUSH(clazz, what) \ + leftStack.push_back(left->cast<clazz>()->what); \ + rightStack.push_back(right->cast<clazz>()->what); + #define CHECK(clazz, what) \ + if (left->cast<clazz>()->what != right->cast<clazz>()->what) return false; + switch (left->_id) { + case Expression::Id::BlockId: { + if (!noteNames(left->cast<Block>()->name, right->cast<Block>()->name)) return false; + CHECK(Block, list.size()); + for (Index i = 0; i < left->cast<Block>()->list.size(); i++) { + PUSH(Block, list[i]); + } + break; + } + case Expression::Id::IfId: { + PUSH(If, condition); + PUSH(If, ifTrue); + PUSH(If, ifFalse); + break; + } + case Expression::Id::LoopId: { + if (!noteNames(left->cast<Loop>()->name, right->cast<Loop>()->name)) return false; + PUSH(Loop, body); + break; + } + case Expression::Id::BreakId: { + if (!checkNames(left->cast<Break>()->name, right->cast<Break>()->name)) return false; + PUSH(Break, condition); + PUSH(Break, value); + break; + } + case Expression::Id::SwitchId: { + CHECK(Switch, targets.size()); + for (Index i = 0; i < left->cast<Switch>()->targets.size(); i++) { + if (!checkNames(left->cast<Switch>()->targets[i], right->cast<Switch>()->targets[i])) return false; + } + if (!checkNames(left->cast<Switch>()->default_, right->cast<Switch>()->default_)) return false; + PUSH(Switch, condition); + PUSH(Switch, value); + break; + } + case Expression::Id::CallId: { + CHECK(Call, target); + CHECK(Call, operands.size()); + for (Index i = 0; i < left->cast<Call>()->operands.size(); i++) { + PUSH(Call, operands[i]); + } + break; + } + case Expression::Id::CallImportId: { + CHECK(CallImport, target); + CHECK(CallImport, operands.size()); + for (Index i = 0; i < left->cast<CallImport>()->operands.size(); i++) { + PUSH(CallImport, operands[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + PUSH(CallIndirect, target); + CHECK(CallIndirect, fullType); + CHECK(CallIndirect, operands.size()); + for (Index i = 0; i < left->cast<CallIndirect>()->operands.size(); i++) { + PUSH(CallIndirect, operands[i]); + } + break; + } + case Expression::Id::GetLocalId: { + CHECK(GetLocal, index); + break; + } + case Expression::Id::SetLocalId: { + CHECK(SetLocal, index); + CHECK(SetLocal, type); // for tee/set + PUSH(SetLocal, value); + break; + } + case Expression::Id::GetGlobalId: { + CHECK(GetGlobal, name); + break; + } + case Expression::Id::SetGlobalId: { + CHECK(SetGlobal, name); + PUSH(SetGlobal, value); + break; + } + case Expression::Id::LoadId: { + CHECK(Load, bytes); + CHECK(Load, signed_); + CHECK(Load, offset); + CHECK(Load, align); + PUSH(Load, ptr); + break; + } + case Expression::Id::StoreId: { + CHECK(Store, bytes); + CHECK(Store, offset); + CHECK(Store, align); + CHECK(Store, valueType); + PUSH(Store, ptr); + PUSH(Store, value); + break; + } + case Expression::Id::ConstId: { + CHECK(Const, value); + break; + } + case Expression::Id::UnaryId: { + CHECK(Unary, op); + PUSH(Unary, value); + break; + } + case Expression::Id::BinaryId: { + CHECK(Binary, op); + PUSH(Binary, left); + PUSH(Binary, right); + break; + } + case Expression::Id::SelectId: { + PUSH(Select, ifTrue); + PUSH(Select, ifFalse); + PUSH(Select, condition); + break; + } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } + case Expression::Id::ReturnId: { + PUSH(Return, value); + break; + } + case Expression::Id::HostId: { + CHECK(Host, op); + CHECK(Host, nameOperand); + CHECK(Host, operands.size()); + for (Index i = 0; i < left->cast<Host>()->operands.size(); i++) { + PUSH(Host, operands[i]); + } + break; + } + case Expression::Id::NopId: { + break; + } + case Expression::Id::UnreachableId: { + break; + } + default: WASM_UNREACHABLE(); + } + #undef CHECK + #undef PUSH + } + if (leftStack.size() > 0 || rightStack.size() > 0) return false; + return true; +} + + +// hash an expression, ignoring superficial details like specific internal names +uint32_t ExpressionAnalyzer::hash(Expression* curr) { + uint32_t digest = 0; + + auto hash = [&digest](uint32_t hash) { + digest = rehash(digest, hash); + }; + auto hash64 = [&digest](uint64_t hash) { + digest = rehash(rehash(digest, hash >> 32), uint32_t(hash)); + }; + + std::vector<Name> nameStack; + Index internalCounter = 0; + std::map<Name, std::vector<Index>> internalNames; // for each internal name, a vector if unique ids + Nop popNameMarker; + std::vector<Expression*> stack; + + auto noteName = [&](Name curr) { + if (curr.is()) { + nameStack.push_back(curr); + internalNames[curr].push_back(internalCounter++); + stack.push_back(&popNameMarker); + } + return true; + }; + auto hashName = [&](Name curr) { + auto iter = internalNames.find(curr); + if (iter == internalNames.end()) hash64(uint64_t(curr.str)); + else hash(iter->second.back()); + }; + auto popName = [&]() { + auto curr = nameStack.back(); + nameStack.pop_back(); + internalNames[curr].pop_back(); + }; + + stack.push_back(curr); + + while (stack.size() > 0) { + curr = stack.back(); + stack.pop_back(); + if (!curr) continue; + if (curr == &popNameMarker) { + popName(); + continue; + } + hash(curr->_id); + // we often don't need to hash the type, as it is tied to other values + // we are hashing anyhow, but there are exceptions: for example, a + // get_local's type is determined by the function, so if we are + // hashing only expression fragments, then two from different + // functions may turn out the same even if the type differs. Likewise, + // if we hash between modules, then we need to take int account + // call_imports type, etc. The simplest thing is just to hash the + // type for all of them. + hash(curr->type); + + #define PUSH(clazz, what) \ + stack.push_back(curr->cast<clazz>()->what); + #define HASH(clazz, what) \ + hash(curr->cast<clazz>()->what); + #define HASH64(clazz, what) \ + hash64(curr->cast<clazz>()->what); + #define HASH_NAME(clazz, what) \ + hash64(uint64_t(curr->cast<clazz>()->what.str)); + #define HASH_PTR(clazz, what) \ + hash64(uint64_t(curr->cast<clazz>()->what)); + switch (curr->_id) { + case Expression::Id::BlockId: { + noteName(curr->cast<Block>()->name); + HASH(Block, list.size()); + for (Index i = 0; i < curr->cast<Block>()->list.size(); i++) { + PUSH(Block, list[i]); + } + break; + } + case Expression::Id::IfId: { + PUSH(If, condition); + PUSH(If, ifTrue); + PUSH(If, ifFalse); + break; + } + case Expression::Id::LoopId: { + noteName(curr->cast<Loop>()->name); + PUSH(Loop, body); + break; + } + case Expression::Id::BreakId: { + hashName(curr->cast<Break>()->name); + PUSH(Break, condition); + PUSH(Break, value); + break; + } + case Expression::Id::SwitchId: { + HASH(Switch, targets.size()); + for (Index i = 0; i < curr->cast<Switch>()->targets.size(); i++) { + hashName(curr->cast<Switch>()->targets[i]); + } + hashName(curr->cast<Switch>()->default_); + PUSH(Switch, condition); + PUSH(Switch, value); + break; + } + case Expression::Id::CallId: { + HASH_NAME(Call, target); + HASH(Call, operands.size()); + for (Index i = 0; i < curr->cast<Call>()->operands.size(); i++) { + PUSH(Call, operands[i]); + } + break; + } + case Expression::Id::CallImportId: { + HASH_NAME(CallImport, target); + HASH(CallImport, operands.size()); + for (Index i = 0; i < curr->cast<CallImport>()->operands.size(); i++) { + PUSH(CallImport, operands[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + PUSH(CallIndirect, target); + HASH_NAME(CallIndirect, fullType); + HASH(CallIndirect, operands.size()); + for (Index i = 0; i < curr->cast<CallIndirect>()->operands.size(); i++) { + PUSH(CallIndirect, operands[i]); + } + break; + } + case Expression::Id::GetLocalId: { + HASH(GetLocal, index); + break; + } + case Expression::Id::SetLocalId: { + HASH(SetLocal, index); + PUSH(SetLocal, value); + break; + } + case Expression::Id::GetGlobalId: { + HASH_NAME(GetGlobal, name); + break; + } + case Expression::Id::SetGlobalId: { + HASH_NAME(SetGlobal, name); + PUSH(SetGlobal, value); + break; + } + case Expression::Id::LoadId: { + HASH(Load, bytes); + HASH(Load, signed_); + HASH(Load, offset); + HASH(Load, align); + PUSH(Load, ptr); + break; + } + case Expression::Id::StoreId: { + HASH(Store, bytes); + HASH(Store, offset); + HASH(Store, align); + HASH(Store, valueType); + PUSH(Store, ptr); + PUSH(Store, value); + break; + } + case Expression::Id::ConstId: { + HASH(Const, value.type); + HASH64(Const, value.getBits()); + break; + } + case Expression::Id::UnaryId: { + HASH(Unary, op); + PUSH(Unary, value); + break; + } + case Expression::Id::BinaryId: { + HASH(Binary, op); + PUSH(Binary, left); + PUSH(Binary, right); + break; + } + case Expression::Id::SelectId: { + PUSH(Select, ifTrue); + PUSH(Select, ifFalse); + PUSH(Select, condition); + break; + } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } + case Expression::Id::ReturnId: { + PUSH(Return, value); + break; + } + case Expression::Id::HostId: { + HASH(Host, op); + HASH_NAME(Host, nameOperand); + HASH(Host, operands.size()); + for (Index i = 0; i < curr->cast<Host>()->operands.size(); i++) { + PUSH(Host, operands[i]); + } + break; + } + case Expression::Id::NopId: { + break; + } + case Expression::Id::UnreachableId: { + break; + } + default: WASM_UNREACHABLE(); + } + #undef HASH + #undef PUSH + } + return digest; +} +} // namespace wasm diff --git a/src/ast/ExpressionManipulator.cpp b/src/ast/ExpressionManipulator.cpp new file mode 100644 index 000000000..fb861f0d8 --- /dev/null +++ b/src/ast/ExpressionManipulator.cpp @@ -0,0 +1,152 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ast_utils.h" +#include "support/hash.h" + +namespace wasm { + +Expression* ExpressionManipulator::flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { + struct Copier : public Visitor<Copier, Expression*> { + Module& wasm; + CustomCopier custom; + + Builder builder; + + Copier(Module& wasm, CustomCopier custom) : wasm(wasm), custom(custom), builder(wasm) {} + + Expression* copy(Expression* curr) { + if (!curr) return nullptr; + auto* ret = custom(curr); + if (ret) return ret; + return Visitor<Copier, Expression*>::visit(curr); + } + + Expression* visitBlock(Block *curr) { + auto* ret = builder.makeBlock(); + for (Index i = 0; i < curr->list.size(); i++) { + ret->list.push_back(copy(curr->list[i])); + } + ret->name = curr->name; + ret->finalize(curr->type); + return ret; + } + Expression* visitIf(If *curr) { + return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); + } + Expression* visitLoop(Loop *curr) { + return builder.makeLoop(curr->name, copy(curr->body)); + } + Expression* visitBreak(Break *curr) { + return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition)); + } + Expression* visitSwitch(Switch *curr) { + return builder.makeSwitch(curr->targets, curr->default_, copy(curr->condition), copy(curr->value)); + } + Expression* visitCall(Call *curr) { + auto* ret = builder.makeCall(curr->target, {}, curr->type); + for (Index i = 0; i < curr->operands.size(); i++) { + ret->operands.push_back(copy(curr->operands[i])); + } + return ret; + } + Expression* visitCallImport(CallImport *curr) { + auto* ret = builder.makeCallImport(curr->target, {}, curr->type); + for (Index i = 0; i < curr->operands.size(); i++) { + ret->operands.push_back(copy(curr->operands[i])); + } + return ret; + } + Expression* visitCallIndirect(CallIndirect *curr) { + auto* ret = builder.makeCallIndirect(curr->fullType, copy(curr->target), {}, curr->type); + for (Index i = 0; i < curr->operands.size(); i++) { + ret->operands.push_back(copy(curr->operands[i])); + } + return ret; + } + Expression* visitGetLocal(GetLocal *curr) { + return builder.makeGetLocal(curr->index, curr->type); + } + Expression* visitSetLocal(SetLocal *curr) { + if (curr->isTee()) { + return builder.makeTeeLocal(curr->index, copy(curr->value)); + } else { + return builder.makeSetLocal(curr->index, copy(curr->value)); + } + } + Expression* visitGetGlobal(GetGlobal *curr) { + return builder.makeGetGlobal(curr->name, curr->type); + } + Expression* visitSetGlobal(SetGlobal *curr) { + return builder.makeSetGlobal(curr->name, copy(curr->value)); + } + Expression* visitLoad(Load *curr) { + return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type); + } + Expression* visitStore(Store *curr) { + return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType); + } + Expression* visitConst(Const *curr) { + return builder.makeConst(curr->value); + } + Expression* visitUnary(Unary *curr) { + return builder.makeUnary(curr->op, copy(curr->value)); + } + Expression* visitBinary(Binary *curr) { + return builder.makeBinary(curr->op, copy(curr->left), copy(curr->right)); + } + Expression* visitSelect(Select *curr) { + return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); + } + Expression* visitDrop(Drop *curr) { + return builder.makeDrop(copy(curr->value)); + } + Expression* visitReturn(Return *curr) { + return builder.makeReturn(copy(curr->value)); + } + Expression* visitHost(Host *curr) { + assert(curr->operands.size() == 0); + return builder.makeHost(curr->op, curr->nameOperand, {}); + } + Expression* visitNop(Nop *curr) { + return builder.makeNop(); + } + Expression* visitUnreachable(Unreachable *curr) { + return builder.makeUnreachable(); + } + }; + + Copier copier(wasm, custom); + return copier.copy(original); +} + + +// Splice an item into the middle of a block's list +void ExpressionManipulator::spliceIntoBlock(Block* block, Index index, Expression* add) { + auto& list = block->list; + if (index == list.size()) { + list.push_back(add); // simple append + } else { + // we need to make room + list.push_back(nullptr); + for (Index i = list.size() - 1; i > index; i--) { + list[i] = list[i - 1]; + } + list[index] = add; + } +} + +} // namespace wasm diff --git a/src/ast_utils.h b/src/ast_utils.h index 159762d0b..2a151d8f7 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -17,7 +17,6 @@ #ifndef wasm_ast_utils_h #define wasm_ast_utils_h -#include "support/hash.h" #include "wasm.h" #include "wasm-traversal.h" #include "wasm-builder.h" @@ -243,202 +242,28 @@ struct ExpressionManipulator { return output; } - template<typename T> - static Expression* flexibleCopy(Expression* original, Module& wasm, T& custom) { - struct Copier : public Visitor<Copier, Expression*> { - Module& wasm; - T& custom; - - Builder builder; - - Copier(Module& wasm, T& custom) : wasm(wasm), custom(custom), builder(wasm) {} - - Expression* copy(Expression* curr) { - if (!curr) return nullptr; - auto* ret = custom.copy(curr); - if (ret) return ret; - return Visitor<Copier, Expression*>::visit(curr); - } - - Expression* visitBlock(Block *curr) { - auto* ret = builder.makeBlock(); - for (Index i = 0; i < curr->list.size(); i++) { - ret->list.push_back(copy(curr->list[i])); - } - ret->name = curr->name; - ret->finalize(curr->type); - return ret; - } - Expression* visitIf(If *curr) { - return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); - } - Expression* visitLoop(Loop *curr) { - return builder.makeLoop(curr->name, copy(curr->body)); - } - Expression* visitBreak(Break *curr) { - return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition)); - } - Expression* visitSwitch(Switch *curr) { - return builder.makeSwitch(curr->targets, curr->default_, copy(curr->condition), copy(curr->value)); - } - Expression* visitCall(Call *curr) { - auto* ret = builder.makeCall(curr->target, {}, curr->type); - for (Index i = 0; i < curr->operands.size(); i++) { - ret->operands.push_back(copy(curr->operands[i])); - } - return ret; - } - Expression* visitCallImport(CallImport *curr) { - auto* ret = builder.makeCallImport(curr->target, {}, curr->type); - for (Index i = 0; i < curr->operands.size(); i++) { - ret->operands.push_back(copy(curr->operands[i])); - } - return ret; - } - Expression* visitCallIndirect(CallIndirect *curr) { - auto* ret = builder.makeCallIndirect(curr->fullType, copy(curr->target), {}, curr->type); - for (Index i = 0; i < curr->operands.size(); i++) { - ret->operands.push_back(copy(curr->operands[i])); - } - return ret; - } - Expression* visitGetLocal(GetLocal *curr) { - return builder.makeGetLocal(curr->index, curr->type); - } - Expression* visitSetLocal(SetLocal *curr) { - if (curr->isTee()) { - return builder.makeTeeLocal(curr->index, copy(curr->value)); - } else { - return builder.makeSetLocal(curr->index, copy(curr->value)); - } - } - Expression* visitGetGlobal(GetGlobal *curr) { - return builder.makeGetGlobal(curr->name, curr->type); - } - Expression* visitSetGlobal(SetGlobal *curr) { - return builder.makeSetGlobal(curr->name, copy(curr->value)); - } - Expression* visitLoad(Load *curr) { - return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type); - } - Expression* visitStore(Store *curr) { - return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType); - } - Expression* visitConst(Const *curr) { - return builder.makeConst(curr->value); - } - Expression* visitUnary(Unary *curr) { - return builder.makeUnary(curr->op, copy(curr->value)); - } - Expression* visitBinary(Binary *curr) { - return builder.makeBinary(curr->op, copy(curr->left), copy(curr->right)); - } - Expression* visitSelect(Select *curr) { - return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); - } - Expression* visitDrop(Drop *curr) { - return builder.makeDrop(copy(curr->value)); - } - Expression* visitReturn(Return *curr) { - return builder.makeReturn(copy(curr->value)); - } - Expression* visitHost(Host *curr) { - assert(curr->operands.size() == 0); - return builder.makeHost(curr->op, curr->nameOperand, {}); - } - Expression* visitNop(Nop *curr) { - return builder.makeNop(); - } - Expression* visitUnreachable(Unreachable *curr) { - return builder.makeUnreachable(); - } - }; - - Copier copier(wasm, custom); - return copier.copy(original); - } + using CustomCopier = std::function<Expression*(Expression*)>; + static Expression* flexibleCopy(Expression* original, Module& wasm, CustomCopier custom); static Expression* copy(Expression* original, Module& wasm) { - struct Copier { - Expression* copy(Expression* curr) { + auto copy = [](Expression* curr) { return nullptr; - } - } copier; - return flexibleCopy(original, wasm, copier); + }; + return flexibleCopy(original, wasm, copy); } // Splice an item into the middle of a block's list - static void spliceIntoBlock(Block* block, Index index, Expression* add) { - auto& list = block->list; - if (index == list.size()) { - list.push_back(add); // simple append - } else { - // we need to make room - list.push_back(nullptr); - for (Index i = list.size() - 1; i > index; i--) { - list[i] = list[i - 1]; - } - list[index] = add; - } - } + static void spliceIntoBlock(Block* block, Index index, Expression* add); }; struct ExpressionAnalyzer { // Given a stack of expressions, checks if the topmost is used as a result. // For example, if the parent is a block and the node is before the last position, // it is not used. - static bool isResultUsed(std::vector<Expression*> stack, Function* func) { - for (int i = int(stack.size()) - 2; i >= 0; i--) { - auto* curr = stack[i]; - auto* above = stack[i + 1]; - // only if and block can drop values (pre-drop expression was added) FIXME - if (curr->is<Block>()) { - auto* block = curr->cast<Block>(); - for (size_t j = 0; j < block->list.size() - 1; j++) { - if (block->list[j] == above) return false; - } - assert(block->list.back() == above); - // continue down - } else if (curr->is<If>()) { - auto* iff = curr->cast<If>(); - if (above == iff->condition) return true; - if (!iff->ifFalse) return false; - assert(above == iff->ifTrue || above == iff->ifFalse); - // continue down - } else { - if (curr->is<Drop>()) return false; - return true; // all other node types use the result - } - } - // The value might be used, so it depends on if the function returns - return func->result != none; - } + static bool isResultUsed(std::vector<Expression*> stack, Function* func); // Checks if a value is dropped. - static bool isResultDropped(std::vector<Expression*> stack) { - for (int i = int(stack.size()) - 2; i >= 0; i--) { - auto* curr = stack[i]; - auto* above = stack[i + 1]; - if (curr->is<Block>()) { - auto* block = curr->cast<Block>(); - for (size_t j = 0; j < block->list.size() - 1; j++) { - if (block->list[j] == above) return false; - } - assert(block->list.back() == above); - // continue down - } else if (curr->is<If>()) { - auto* iff = curr->cast<If>(); - if (above == iff->condition) return false; - if (!iff->ifFalse) return false; - assert(above == iff->ifTrue || above == iff->ifFalse); - // continue down - } else { - if (curr->is<Drop>()) return true; // dropped - return false; // all other node types use the result - } - } - return false; - } + static bool isResultDropped(std::vector<Expression*> stack); // Checks if a break is a simple - no condition, no value, just a plain branching static bool isSimple(Break* curr) { @@ -457,431 +282,18 @@ struct ExpressionAnalyzer { return false; } - template<typename T> - static bool flexibleEqual(Expression* left, Expression* right, T& comparer) { - std::vector<Name> nameStack; - std::map<Name, std::vector<Name>> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively - Nop popNameMarker; - std::vector<Expression*> leftStack; - std::vector<Expression*> rightStack; - - auto noteNames = [&](Name left, Name right) { - if (left.is() != right.is()) return false; - if (left.is()) { - nameStack.push_back(left); - rightNames[left].push_back(right); - leftStack.push_back(&popNameMarker); - rightStack.push_back(&popNameMarker); - } - return true; - }; - auto checkNames = [&](Name left, Name right) { - auto iter = rightNames.find(left); - if (iter == rightNames.end()) return left == right; // non-internal name - return iter->second.back() == right; - }; - auto popName = [&]() { - auto left = nameStack.back(); - nameStack.pop_back(); - rightNames[left].pop_back(); - }; - - leftStack.push_back(left); - rightStack.push_back(right); - - while (leftStack.size() > 0 && rightStack.size() > 0) { - left = leftStack.back(); - leftStack.pop_back(); - right = rightStack.back(); - rightStack.pop_back(); - if (!left != !right) return false; - if (!left) continue; - if (left == &popNameMarker) { - popName(); - continue; - } - if (comparer.compare(left, right)) continue; // comparison hook, before all the rest - // continue with normal structural comparison - if (left->_id != right->_id) return false; - #define PUSH(clazz, what) \ - leftStack.push_back(left->cast<clazz>()->what); \ - rightStack.push_back(right->cast<clazz>()->what); - #define CHECK(clazz, what) \ - if (left->cast<clazz>()->what != right->cast<clazz>()->what) return false; - switch (left->_id) { - case Expression::Id::BlockId: { - if (!noteNames(left->cast<Block>()->name, right->cast<Block>()->name)) return false; - CHECK(Block, list.size()); - for (Index i = 0; i < left->cast<Block>()->list.size(); i++) { - PUSH(Block, list[i]); - } - break; - } - case Expression::Id::IfId: { - PUSH(If, condition); - PUSH(If, ifTrue); - PUSH(If, ifFalse); - break; - } - case Expression::Id::LoopId: { - if (!noteNames(left->cast<Loop>()->name, right->cast<Loop>()->name)) return false; - PUSH(Loop, body); - break; - } - case Expression::Id::BreakId: { - if (!checkNames(left->cast<Break>()->name, right->cast<Break>()->name)) return false; - PUSH(Break, condition); - PUSH(Break, value); - break; - } - case Expression::Id::SwitchId: { - CHECK(Switch, targets.size()); - for (Index i = 0; i < left->cast<Switch>()->targets.size(); i++) { - if (!checkNames(left->cast<Switch>()->targets[i], right->cast<Switch>()->targets[i])) return false; - } - if (!checkNames(left->cast<Switch>()->default_, right->cast<Switch>()->default_)) return false; - PUSH(Switch, condition); - PUSH(Switch, value); - break; - } - case Expression::Id::CallId: { - CHECK(Call, target); - CHECK(Call, operands.size()); - for (Index i = 0; i < left->cast<Call>()->operands.size(); i++) { - PUSH(Call, operands[i]); - } - break; - } - case Expression::Id::CallImportId: { - CHECK(CallImport, target); - CHECK(CallImport, operands.size()); - for (Index i = 0; i < left->cast<CallImport>()->operands.size(); i++) { - PUSH(CallImport, operands[i]); - } - break; - } - case Expression::Id::CallIndirectId: { - PUSH(CallIndirect, target); - CHECK(CallIndirect, fullType); - CHECK(CallIndirect, operands.size()); - for (Index i = 0; i < left->cast<CallIndirect>()->operands.size(); i++) { - PUSH(CallIndirect, operands[i]); - } - break; - } - case Expression::Id::GetLocalId: { - CHECK(GetLocal, index); - break; - } - case Expression::Id::SetLocalId: { - CHECK(SetLocal, index); - CHECK(SetLocal, type); // for tee/set - PUSH(SetLocal, value); - break; - } - case Expression::Id::GetGlobalId: { - CHECK(GetGlobal, name); - break; - } - case Expression::Id::SetGlobalId: { - CHECK(SetGlobal, name); - PUSH(SetGlobal, value); - break; - } - case Expression::Id::LoadId: { - CHECK(Load, bytes); - CHECK(Load, signed_); - CHECK(Load, offset); - CHECK(Load, align); - PUSH(Load, ptr); - break; - } - case Expression::Id::StoreId: { - CHECK(Store, bytes); - CHECK(Store, offset); - CHECK(Store, align); - CHECK(Store, valueType); - PUSH(Store, ptr); - PUSH(Store, value); - break; - } - case Expression::Id::ConstId: { - CHECK(Const, value); - break; - } - case Expression::Id::UnaryId: { - CHECK(Unary, op); - PUSH(Unary, value); - break; - } - case Expression::Id::BinaryId: { - CHECK(Binary, op); - PUSH(Binary, left); - PUSH(Binary, right); - break; - } - case Expression::Id::SelectId: { - PUSH(Select, ifTrue); - PUSH(Select, ifFalse); - PUSH(Select, condition); - break; - } - case Expression::Id::DropId: { - PUSH(Drop, value); - break; - } - case Expression::Id::ReturnId: { - PUSH(Return, value); - break; - } - case Expression::Id::HostId: { - CHECK(Host, op); - CHECK(Host, nameOperand); - CHECK(Host, operands.size()); - for (Index i = 0; i < left->cast<Host>()->operands.size(); i++) { - PUSH(Host, operands[i]); - } - break; - } - case Expression::Id::NopId: { - break; - } - case Expression::Id::UnreachableId: { - break; - } - default: WASM_UNREACHABLE(); - } - #undef CHECK - #undef PUSH - } - if (leftStack.size() > 0 || rightStack.size() > 0) return false; - return true; - } + using ExprComparer = std::function<bool(Expression*, Expression*)>; + static bool flexibleEqual(Expression* left, Expression* right, ExprComparer comparer); static bool equal(Expression* left, Expression* right) { - struct Comparer { - bool compare(Expression* left, Expression* right) { - return false; - } - } comparer; + auto comparer = [](Expression* left, Expression* right) { + return false; + }; return flexibleEqual(left, right, comparer); } // hash an expression, ignoring superficial details like specific internal names - static uint32_t hash(Expression* curr) { - uint32_t digest = 0; - - auto hash = [&digest](uint32_t hash) { - digest = rehash(digest, hash); - }; - auto hash64 = [&digest](uint64_t hash) { - digest = rehash(rehash(digest, hash >> 32), uint32_t(hash)); - }; - - std::vector<Name> nameStack; - Index internalCounter = 0; - std::map<Name, std::vector<Index>> internalNames; // for each internal name, a vector if unique ids - Nop popNameMarker; - std::vector<Expression*> stack; - - auto noteName = [&](Name curr) { - if (curr.is()) { - nameStack.push_back(curr); - internalNames[curr].push_back(internalCounter++); - stack.push_back(&popNameMarker); - } - return true; - }; - auto hashName = [&](Name curr) { - auto iter = internalNames.find(curr); - if (iter == internalNames.end()) hash64(uint64_t(curr.str)); - else hash(iter->second.back()); - }; - auto popName = [&]() { - auto curr = nameStack.back(); - nameStack.pop_back(); - internalNames[curr].pop_back(); - }; - - stack.push_back(curr); - - while (stack.size() > 0) { - curr = stack.back(); - stack.pop_back(); - if (!curr) continue; - if (curr == &popNameMarker) { - popName(); - continue; - } - hash(curr->_id); - // we often don't need to hash the type, as it is tied to other values - // we are hashing anyhow, but there are exceptions: for example, a - // get_local's type is determined by the function, so if we are - // hashing only expression fragments, then two from different - // functions may turn out the same even if the type differs. Likewise, - // if we hash between modules, then we need to take int account - // call_imports type, etc. The simplest thing is just to hash the - // type for all of them. - hash(curr->type); - - #define PUSH(clazz, what) \ - stack.push_back(curr->cast<clazz>()->what); - #define HASH(clazz, what) \ - hash(curr->cast<clazz>()->what); - #define HASH64(clazz, what) \ - hash64(curr->cast<clazz>()->what); - #define HASH_NAME(clazz, what) \ - hash64(uint64_t(curr->cast<clazz>()->what.str)); - #define HASH_PTR(clazz, what) \ - hash64(uint64_t(curr->cast<clazz>()->what)); - switch (curr->_id) { - case Expression::Id::BlockId: { - noteName(curr->cast<Block>()->name); - HASH(Block, list.size()); - for (Index i = 0; i < curr->cast<Block>()->list.size(); i++) { - PUSH(Block, list[i]); - } - break; - } - case Expression::Id::IfId: { - PUSH(If, condition); - PUSH(If, ifTrue); - PUSH(If, ifFalse); - break; - } - case Expression::Id::LoopId: { - noteName(curr->cast<Loop>()->name); - PUSH(Loop, body); - break; - } - case Expression::Id::BreakId: { - hashName(curr->cast<Break>()->name); - PUSH(Break, condition); - PUSH(Break, value); - break; - } - case Expression::Id::SwitchId: { - HASH(Switch, targets.size()); - for (Index i = 0; i < curr->cast<Switch>()->targets.size(); i++) { - hashName(curr->cast<Switch>()->targets[i]); - } - hashName(curr->cast<Switch>()->default_); - PUSH(Switch, condition); - PUSH(Switch, value); - break; - } - case Expression::Id::CallId: { - HASH_NAME(Call, target); - HASH(Call, operands.size()); - for (Index i = 0; i < curr->cast<Call>()->operands.size(); i++) { - PUSH(Call, operands[i]); - } - break; - } - case Expression::Id::CallImportId: { - HASH_NAME(CallImport, target); - HASH(CallImport, operands.size()); - for (Index i = 0; i < curr->cast<CallImport>()->operands.size(); i++) { - PUSH(CallImport, operands[i]); - } - break; - } - case Expression::Id::CallIndirectId: { - PUSH(CallIndirect, target); - HASH_NAME(CallIndirect, fullType); - HASH(CallIndirect, operands.size()); - for (Index i = 0; i < curr->cast<CallIndirect>()->operands.size(); i++) { - PUSH(CallIndirect, operands[i]); - } - break; - } - case Expression::Id::GetLocalId: { - HASH(GetLocal, index); - break; - } - case Expression::Id::SetLocalId: { - HASH(SetLocal, index); - PUSH(SetLocal, value); - break; - } - case Expression::Id::GetGlobalId: { - HASH_NAME(GetGlobal, name); - break; - } - case Expression::Id::SetGlobalId: { - HASH_NAME(SetGlobal, name); - PUSH(SetGlobal, value); - break; - } - case Expression::Id::LoadId: { - HASH(Load, bytes); - HASH(Load, signed_); - HASH(Load, offset); - HASH(Load, align); - PUSH(Load, ptr); - break; - } - case Expression::Id::StoreId: { - HASH(Store, bytes); - HASH(Store, offset); - HASH(Store, align); - HASH(Store, valueType); - PUSH(Store, ptr); - PUSH(Store, value); - break; - } - case Expression::Id::ConstId: { - HASH(Const, value.type); - HASH64(Const, value.getBits()); - break; - } - case Expression::Id::UnaryId: { - HASH(Unary, op); - PUSH(Unary, value); - break; - } - case Expression::Id::BinaryId: { - HASH(Binary, op); - PUSH(Binary, left); - PUSH(Binary, right); - break; - } - case Expression::Id::SelectId: { - PUSH(Select, ifTrue); - PUSH(Select, ifFalse); - PUSH(Select, condition); - break; - } - case Expression::Id::DropId: { - PUSH(Drop, value); - break; - } - case Expression::Id::ReturnId: { - PUSH(Return, value); - break; - } - case Expression::Id::HostId: { - HASH(Host, op); - HASH_NAME(Host, nameOperand); - HASH(Host, operands.size()); - for (Index i = 0; i < curr->cast<Host>()->operands.size(); i++) { - PUSH(Host, operands[i]); - } - break; - } - case Expression::Id::NopId: { - break; - } - case Expression::Id::UnreachableId: { - break; - } - default: WASM_UNREACHABLE(); - } - #undef HASH - #undef PUSH - } - return digest; - } + static uint32_t hash(Expression* curr); }; // Finalizes a node diff --git a/src/passes/DuplicateFunctionElimination.cpp b/src/passes/DuplicateFunctionElimination.cpp index 8e8342729..ee7fbb2a5 100644 --- a/src/passes/DuplicateFunctionElimination.cpp +++ b/src/passes/DuplicateFunctionElimination.cpp @@ -20,9 +20,10 @@ // identical when finally lowered into concrete wasm code. // -#include <wasm.h> -#include <pass.h> -#include <ast_utils.h> +#include "wasm.h" +#include "pass.h" +#include "ast_utils.h" +#include "support/hash.h" namespace wasm { @@ -181,4 +182,3 @@ Pass *createDuplicateFunctionEliminationPass() { } } // namespace wasm - diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 04216f1ee..5c1aa5fc3 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -103,61 +103,63 @@ struct Match { bool check(Expression* seen) { // compare seen to the pattern input, doing a special operation for our "wildcards" assert(wildcards.size() == 0); - return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, *this); - } - - bool compare(Expression* subInput, Expression* subSeen) { - CallImport* call = subInput->dynCast<CallImport>(); - if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return false; - Index index = call->operands[0]->cast<Const>()->value.geti32(); - // handle our special functions - auto checkMatch = [&](WasmType type) { - if (type != none && subSeen->type != type) return false; - while (index >= wildcards.size()) { - wildcards.push_back(nullptr); - } - if (!wildcards[index]) { - // new wildcard - wildcards[index] = subSeen; // NB: no need to copy - return true; - } else { - // We are seeing this index for a second or later time, check it matches - return ExpressionAnalyzer::equal(subSeen, wildcards[index]); + auto compare = [this](Expression* subInput, Expression* subSeen) { + CallImport* call = subInput->dynCast<CallImport>(); + if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return false; + Index index = call->operands[0]->cast<Const>()->value.geti32(); + // handle our special functions + auto checkMatch = [&](WasmType type) { + if (type != none && subSeen->type != type) return false; + while (index >= wildcards.size()) { + wildcards.push_back(nullptr); + } + if (!wildcards[index]) { + // new wildcard + wildcards[index] = subSeen; // NB: no need to copy + return true; + } else { + // We are seeing this index for a second or later time, check it matches + return ExpressionAnalyzer::equal(subSeen, wildcards[index]); + }; }; + if (call->target == I32_EXPR) { + if (checkMatch(i32)) return true; + } else if (call->target == I64_EXPR) { + if (checkMatch(i64)) return true; + } else if (call->target == F32_EXPR) { + if (checkMatch(f32)) return true; + } else if (call->target == F64_EXPR) { + if (checkMatch(f64)) return true; + } else if (call->target == ANY_EXPR) { + if (checkMatch(none)) return true; + } + return false; }; - if (call->target == I32_EXPR) { - if (checkMatch(i32)) return true; - } else if (call->target == I64_EXPR) { - if (checkMatch(i64)) return true; - } else if (call->target == F32_EXPR) { - if (checkMatch(f32)) return true; - } else if (call->target == F64_EXPR) { - if (checkMatch(f64)) return true; - } else if (call->target == ANY_EXPR) { - if (checkMatch(none)) return true; - } - return false; + + return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, compare); } + // Applying/copying // Apply the match, generate an output expression from the matched input, performing substitutions as necessary Expression* apply() { - return ExpressionManipulator::flexibleCopy(pattern.output, wasm, *this); + // When copying a wildcard, perform the substitution. + // TODO: we can reuse nodes, not copying a wildcard when it appears just once, and we can reuse other individual nodes when they are discarded anyhow. + auto copy = [this](Expression* curr) -> Expression* { + CallImport* call = curr->dynCast<CallImport>(); + if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return nullptr; + Index index = call->operands[0]->cast<Const>()->value.geti32(); + // handle our special functions + if (call->target == I32_EXPR || call->target == I64_EXPR || call->target == F32_EXPR || call->target == F64_EXPR || call->target == ANY_EXPR) { + return ExpressionManipulator::copy(wildcards.at(index), wasm); + } + return nullptr; + }; + return ExpressionManipulator::flexibleCopy(pattern.output, wasm, copy); } - // When copying a wildcard, perform the substitution. - // TODO: we can reuse nodes, not copying a wildcard when it appears just once, and we can reuse other individual nodes when they are discarded anyhow. - Expression* copy(Expression* curr) { - CallImport* call = curr->dynCast<CallImport>(); - if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return nullptr; - Index index = call->operands[0]->cast<Const>()->value.geti32(); - // handle our special functions - if (call->target == I32_EXPR || call->target == I64_EXPR || call->target == F32_EXPR || call->target == F64_EXPR || call->target == ANY_EXPR) { - return ExpressionManipulator::copy(wildcards.at(index), wasm); - } - return nullptr; - } + }; // Main pass class |