Diffstat (limited to 'src/ir')
-rw-r--r-- | src/ir/CMakeLists.txt | 6
-rw-r--r-- | src/ir/ExpressionAnalyzer.cpp | 558
-rw-r--r-- | src/ir/ExpressionManipulator.cpp | 179
-rw-r--r-- | src/ir/LocalGraph.cpp | 273
-rw-r--r-- | src/ir/bits.h | 107
-rw-r--r-- | src/ir/block-utils.h | 67
-rw-r--r-- | src/ir/branch-utils.h | 183
-rw-r--r-- | src/ir/cost.h | 255
-rw-r--r-- | src/ir/count.h | 50
-rw-r--r-- | src/ir/effects.h | 278
-rw-r--r-- | src/ir/find_all.h | 48
-rw-r--r-- | src/ir/global-utils.h | 55
-rw-r--r-- | src/ir/hashed.h | 59
-rw-r--r-- | src/ir/import-utils.h | 41
-rw-r--r-- | src/ir/label-utils.h | 62
-rw-r--r-- | src/ir/literal-utils.h | 56
-rw-r--r-- | src/ir/load-utils.h | 40
-rw-r--r-- | src/ir/local-graph.h | 111
-rw-r--r-- | src/ir/localize.h | 47
-rw-r--r-- | src/ir/manipulation.h | 69
-rw-r--r-- | src/ir/memory-utils.h | 56
-rw-r--r-- | src/ir/module-utils.h | 59
-rw-r--r-- | src/ir/properties.h | 141
-rw-r--r-- | src/ir/trapping.h | 120
-rw-r--r-- | src/ir/type-updating.h | 286
-rw-r--r-- | src/ir/utils.h | 360
26 files changed, 3566 insertions, 0 deletions
diff --git a/src/ir/CMakeLists.txt b/src/ir/CMakeLists.txt new file mode 100644 index 000000000..607207968 --- /dev/null +++ b/src/ir/CMakeLists.txt @@ -0,0 +1,6 @@ +SET(ir_SOURCES + ExpressionAnalyzer.cpp + ExpressionManipulator.cpp + LocalGraph.cpp +) +ADD_LIBRARY(ir STATIC ${ir_SOURCES}) diff --git a/src/ir/ExpressionAnalyzer.cpp b/src/ir/ExpressionAnalyzer.cpp new file mode 100644 index 000000000..05450d567 --- /dev/null +++ b/src/ir/ExpressionAnalyzer.cpp @@ -0,0 +1,558 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "support/hash.h" +#include "ir/utils.h" +#include "ir/load-utils.h" + +namespace wasm { +// Given a stack of expressions, checks if the topmost is used as a result. +// For example, if the parent is a block and the node is before the last position, +// it is not used. +bool ExpressionAnalyzer::isResultUsed(std::vector<Expression*> stack, Function* func) { + for (int i = int(stack.size()) - 2; i >= 0; i--) { + auto* curr = stack[i]; + auto* above = stack[i + 1]; + // only if and block can drop values (pre-drop expression was added) FIXME + if (curr->is<Block>()) { + auto* block = curr->cast<Block>(); + for (size_t j = 0; j < block->list.size() - 1; j++) { + if (block->list[j] == above) return false; + } + assert(block->list.back() == above); + // continue down + } else if (curr->is<If>()) { + auto* iff = curr->cast<If>(); + if (above == iff->condition) return true; + if (!iff->ifFalse) return false; + assert(above == iff->ifTrue || above == iff->ifFalse); + // continue down + } else { + if (curr->is<Drop>()) return false; + return true; // all other node types use the result + } + } + // The value might be used, so it depends on if the function returns + return func->result != none; +} + +// Checks if a value is dropped. 
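As a usage sketch (not part of the patch): a pass that tracks its ancestor stack can feed it to isResultUsed directly. This assumes Binaryen's ExpressionStackWalker from wasm-traversal.h, which maintains an expressionStack member during the walk; the pass name here is illustrative. The isResultDropped check introduced by the comment above follows it.

  struct UnusedResultFinder
    : public ExpressionStackWalker<UnusedResultFinder,
                                   UnifiedExpressionVisitor<UnusedResultFinder>> {
    Function* func;
    void visitExpression(Expression* curr) {
      // expressionStack currently ends at curr; ask whether any ancestor consumes its value
      if (isConcreteWasmType(curr->type) &&
          !ExpressionAnalyzer::isResultUsed(expressionStack, func)) {
        // curr computes a value nothing reads; a pass could wrap it in a Drop
      }
    }
  };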
+bool ExpressionAnalyzer::isResultDropped(std::vector<Expression*> stack) {
+  for (int i = int(stack.size()) - 2; i >= 0; i--) {
+    auto* curr = stack[i];
+    auto* above = stack[i + 1];
+    if (curr->is<Block>()) {
+      auto* block = curr->cast<Block>();
+      for (size_t j = 0; j < block->list.size() - 1; j++) {
+        if (block->list[j] == above) return false;
+      }
+      assert(block->list.back() == above);
+      // continue down
+    } else if (curr->is<If>()) {
+      auto* iff = curr->cast<If>();
+      if (above == iff->condition) return false;
+      if (!iff->ifFalse) return false;
+      assert(above == iff->ifTrue || above == iff->ifFalse);
+      // continue down
+    } else {
+      if (curr->is<Drop>()) return true; // dropped
+      return false; // all other node types use the result
+    }
+  }
+  return false;
+}
+
+
+bool ExpressionAnalyzer::flexibleEqual(Expression* left, Expression* right, ExprComparer comparer) {
+  std::vector<Name> nameStack;
+  std::map<Name, std::vector<Name>> rightNames; // for each name on the left, the stack of names on the right (a stack, since names are scoped and can nest duplicatively)
+  Nop popNameMarker;
+  std::vector<Expression*> leftStack;
+  std::vector<Expression*> rightStack;
+
+  auto noteNames = [&](Name left, Name right) {
+    if (left.is() != right.is()) return false;
+    if (left.is()) {
+      nameStack.push_back(left);
+      rightNames[left].push_back(right);
+      leftStack.push_back(&popNameMarker);
+      rightStack.push_back(&popNameMarker);
+    }
+    return true;
+  };
+  auto checkNames = [&](Name left, Name right) {
+    auto iter = rightNames.find(left);
+    if (iter == rightNames.end()) return left == right; // non-internal name
+    return iter->second.back() == right;
+  };
+  auto popName = [&]() {
+    auto left = nameStack.back();
+    nameStack.pop_back();
+    rightNames[left].pop_back();
+  };
+
+  leftStack.push_back(left);
+  rightStack.push_back(right);
+
+  while (leftStack.size() > 0 && rightStack.size() > 0) {
+    left = leftStack.back();
+    leftStack.pop_back();
+    right = rightStack.back();
+    rightStack.pop_back();
+    if (!left != !right) return false;
+    if (!left) continue;
+    if (left == &popNameMarker) {
+      popName();
+      continue;
+    }
+    if (comparer(left, right)) continue; // comparison hook, before all the rest
+    // continue with normal structural comparison
+    if (left->_id != right->_id) return false;
+    #define PUSH(clazz, what) \
+      leftStack.push_back(left->cast<clazz>()->what); \
+      rightStack.push_back(right->cast<clazz>()->what);
+    #define CHECK(clazz, what) \
+      if (left->cast<clazz>()->what != right->cast<clazz>()->what) return false;
+    switch (left->_id) {
+      case Expression::Id::BlockId: {
+        if (!noteNames(left->cast<Block>()->name, right->cast<Block>()->name)) return false;
+        CHECK(Block, list.size());
+        for (Index i = 0; i < left->cast<Block>()->list.size(); i++) {
+          PUSH(Block, list[i]);
+        }
+        break;
+      }
+      case Expression::Id::IfId: {
+        PUSH(If, condition);
+        PUSH(If, ifTrue);
+        PUSH(If, ifFalse);
+        break;
+      }
+      case Expression::Id::LoopId: {
+        if (!noteNames(left->cast<Loop>()->name, right->cast<Loop>()->name)) return false;
+        PUSH(Loop, body);
+        break;
+      }
+      case Expression::Id::BreakId: {
+        if (!checkNames(left->cast<Break>()->name, right->cast<Break>()->name)) return false;
+        PUSH(Break, condition);
+        PUSH(Break, value);
+        break;
+      }
+      case Expression::Id::SwitchId: {
+        CHECK(Switch, targets.size());
+        for (Index i = 0; i < left->cast<Switch>()->targets.size(); i++) {
+          if (!checkNames(left->cast<Switch>()->targets[i], right->cast<Switch>()->targets[i])) return false;
+        }
+        if
(!checkNames(left->cast<Switch>()->default_, right->cast<Switch>()->default_)) return false; + PUSH(Switch, condition); + PUSH(Switch, value); + break; + } + case Expression::Id::CallId: { + CHECK(Call, target); + CHECK(Call, operands.size()); + for (Index i = 0; i < left->cast<Call>()->operands.size(); i++) { + PUSH(Call, operands[i]); + } + break; + } + case Expression::Id::CallImportId: { + CHECK(CallImport, target); + CHECK(CallImport, operands.size()); + for (Index i = 0; i < left->cast<CallImport>()->operands.size(); i++) { + PUSH(CallImport, operands[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + PUSH(CallIndirect, target); + CHECK(CallIndirect, fullType); + CHECK(CallIndirect, operands.size()); + for (Index i = 0; i < left->cast<CallIndirect>()->operands.size(); i++) { + PUSH(CallIndirect, operands[i]); + } + break; + } + case Expression::Id::GetLocalId: { + CHECK(GetLocal, index); + break; + } + case Expression::Id::SetLocalId: { + CHECK(SetLocal, index); + CHECK(SetLocal, type); // for tee/set + PUSH(SetLocal, value); + break; + } + case Expression::Id::GetGlobalId: { + CHECK(GetGlobal, name); + break; + } + case Expression::Id::SetGlobalId: { + CHECK(SetGlobal, name); + PUSH(SetGlobal, value); + break; + } + case Expression::Id::LoadId: { + CHECK(Load, bytes); + if (LoadUtils::isSignRelevant(left->cast<Load>()) && + LoadUtils::isSignRelevant(right->cast<Load>())) { + CHECK(Load, signed_); + } + CHECK(Load, offset); + CHECK(Load, align); + PUSH(Load, ptr); + break; + } + case Expression::Id::StoreId: { + CHECK(Store, bytes); + CHECK(Store, offset); + CHECK(Store, align); + CHECK(Store, valueType); + PUSH(Store, ptr); + PUSH(Store, value); + break; + } + case Expression::Id::AtomicCmpxchgId: { + CHECK(AtomicCmpxchg, bytes); + CHECK(AtomicCmpxchg, offset); + PUSH(AtomicCmpxchg, ptr); + PUSH(AtomicCmpxchg, expected); + PUSH(AtomicCmpxchg, replacement); + break; + } + case Expression::Id::AtomicRMWId: { + CHECK(AtomicRMW, op); + CHECK(AtomicRMW, bytes); + CHECK(AtomicRMW, offset); + PUSH(AtomicRMW, ptr); + PUSH(AtomicRMW, value); + break; + } + case Expression::Id::AtomicWaitId: { + CHECK(AtomicWait, expectedType); + PUSH(AtomicWait, ptr); + PUSH(AtomicWait, expected); + PUSH(AtomicWait, timeout); + break; + } + case Expression::Id::AtomicWakeId: { + PUSH(AtomicWake, ptr); + PUSH(AtomicWake, wakeCount); + break; + } + case Expression::Id::ConstId: { + if (!left->cast<Const>()->value.bitwiseEqual(right->cast<Const>()->value)) { + return false; + } + break; + } + case Expression::Id::UnaryId: { + CHECK(Unary, op); + PUSH(Unary, value); + break; + } + case Expression::Id::BinaryId: { + CHECK(Binary, op); + PUSH(Binary, left); + PUSH(Binary, right); + break; + } + case Expression::Id::SelectId: { + PUSH(Select, ifTrue); + PUSH(Select, ifFalse); + PUSH(Select, condition); + break; + } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } + case Expression::Id::ReturnId: { + PUSH(Return, value); + break; + } + case Expression::Id::HostId: { + CHECK(Host, op); + CHECK(Host, nameOperand); + CHECK(Host, operands.size()); + for (Index i = 0; i < left->cast<Host>()->operands.size(); i++) { + PUSH(Host, operands[i]); + } + break; + } + case Expression::Id::NopId: { + break; + } + case Expression::Id::UnreachableId: { + break; + } + default: WASM_UNREACHABLE(); + } + #undef CHECK + #undef PUSH + } + if (leftStack.size() > 0 || rightStack.size() > 0) return false; + return true; +} + + +// hash an expression, ignoring superficial details like specific internal names 
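The comparer hook runs before structural comparison, so a caller can declare chosen pairs equal up front; strict equality is simply a comparer that always declines. A sketch (a and b are arbitrary Expression pointers; when the hook accepts a pair, their children are skipped entirely, since they are never pushed onto the work stacks):

  // plain structural equality
  bool same = ExpressionAnalyzer::flexibleEqual(a, b,
      [](Expression*, Expression*) { return false; });

  // treat any two constants as interchangeable, comparing the rest structurally
  bool sameUpToConstants = ExpressionAnalyzer::flexibleEqual(a, b,
      [](Expression* l, Expression* r) {
        return l->is<Const>() && r->is<Const>();
      });

The hash below follows the same name-scoping discipline, so that structurally equal expressions hash identically.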
+uint32_t ExpressionAnalyzer::hash(Expression* curr) { + uint32_t digest = 0; + + auto hash = [&digest](uint32_t hash) { + digest = rehash(digest, hash); + }; + auto hash64 = [&digest](uint64_t hash) { + digest = rehash(rehash(digest, uint32_t(hash >> 32)), uint32_t(hash)); + }; + + std::vector<Name> nameStack; + Index internalCounter = 0; + std::map<Name, std::vector<Index>> internalNames; // for each internal name, a vector if unique ids + Nop popNameMarker; + std::vector<Expression*> stack; + + auto noteName = [&](Name curr) { + if (curr.is()) { + nameStack.push_back(curr); + internalNames[curr].push_back(internalCounter++); + stack.push_back(&popNameMarker); + } + return true; + }; + auto hashName = [&](Name curr) { + auto iter = internalNames.find(curr); + if (iter == internalNames.end()) hash64(uint64_t(curr.str)); + else hash(iter->second.back()); + }; + auto popName = [&]() { + auto curr = nameStack.back(); + nameStack.pop_back(); + internalNames[curr].pop_back(); + }; + + stack.push_back(curr); + + while (stack.size() > 0) { + curr = stack.back(); + stack.pop_back(); + if (!curr) continue; + if (curr == &popNameMarker) { + popName(); + continue; + } + hash(curr->_id); + // we often don't need to hash the type, as it is tied to other values + // we are hashing anyhow, but there are exceptions: for example, a + // get_local's type is determined by the function, so if we are + // hashing only expression fragments, then two from different + // functions may turn out the same even if the type differs. Likewise, + // if we hash between modules, then we need to take int account + // call_imports type, etc. The simplest thing is just to hash the + // type for all of them. + hash(curr->type); + + #define PUSH(clazz, what) \ + stack.push_back(curr->cast<clazz>()->what); + #define HASH(clazz, what) \ + hash(curr->cast<clazz>()->what); + #define HASH64(clazz, what) \ + hash64(curr->cast<clazz>()->what); + #define HASH_NAME(clazz, what) \ + hash64(uint64_t(curr->cast<clazz>()->what.str)); + #define HASH_PTR(clazz, what) \ + hash64(uint64_t(curr->cast<clazz>()->what)); + switch (curr->_id) { + case Expression::Id::BlockId: { + noteName(curr->cast<Block>()->name); + HASH(Block, list.size()); + for (Index i = 0; i < curr->cast<Block>()->list.size(); i++) { + PUSH(Block, list[i]); + } + break; + } + case Expression::Id::IfId: { + PUSH(If, condition); + PUSH(If, ifTrue); + PUSH(If, ifFalse); + break; + } + case Expression::Id::LoopId: { + noteName(curr->cast<Loop>()->name); + PUSH(Loop, body); + break; + } + case Expression::Id::BreakId: { + hashName(curr->cast<Break>()->name); + PUSH(Break, condition); + PUSH(Break, value); + break; + } + case Expression::Id::SwitchId: { + HASH(Switch, targets.size()); + for (Index i = 0; i < curr->cast<Switch>()->targets.size(); i++) { + hashName(curr->cast<Switch>()->targets[i]); + } + hashName(curr->cast<Switch>()->default_); + PUSH(Switch, condition); + PUSH(Switch, value); + break; + } + case Expression::Id::CallId: { + HASH_NAME(Call, target); + HASH(Call, operands.size()); + for (Index i = 0; i < curr->cast<Call>()->operands.size(); i++) { + PUSH(Call, operands[i]); + } + break; + } + case Expression::Id::CallImportId: { + HASH_NAME(CallImport, target); + HASH(CallImport, operands.size()); + for (Index i = 0; i < curr->cast<CallImport>()->operands.size(); i++) { + PUSH(CallImport, operands[i]); + } + break; + } + case Expression::Id::CallIndirectId: { + PUSH(CallIndirect, target); + HASH_NAME(CallIndirect, fullType); + HASH(CallIndirect, operands.size()); + 
for (Index i = 0; i < curr->cast<CallIndirect>()->operands.size(); i++) { + PUSH(CallIndirect, operands[i]); + } + break; + } + case Expression::Id::GetLocalId: { + HASH(GetLocal, index); + break; + } + case Expression::Id::SetLocalId: { + HASH(SetLocal, index); + PUSH(SetLocal, value); + break; + } + case Expression::Id::GetGlobalId: { + HASH_NAME(GetGlobal, name); + break; + } + case Expression::Id::SetGlobalId: { + HASH_NAME(SetGlobal, name); + PUSH(SetGlobal, value); + break; + } + case Expression::Id::LoadId: { + HASH(Load, bytes); + if (LoadUtils::isSignRelevant(curr->cast<Load>())) { + HASH(Load, signed_); + } + HASH(Load, offset); + HASH(Load, align); + PUSH(Load, ptr); + break; + } + case Expression::Id::StoreId: { + HASH(Store, bytes); + HASH(Store, offset); + HASH(Store, align); + HASH(Store, valueType); + PUSH(Store, ptr); + PUSH(Store, value); + break; + } + case Expression::Id::AtomicCmpxchgId: { + HASH(AtomicCmpxchg, bytes); + HASH(AtomicCmpxchg, offset); + PUSH(AtomicCmpxchg, ptr); + PUSH(AtomicCmpxchg, expected); + PUSH(AtomicCmpxchg, replacement); + break; + } + case Expression::Id::AtomicRMWId: { + HASH(AtomicRMW, op); + HASH(AtomicRMW, bytes); + HASH(AtomicRMW, offset); + PUSH(AtomicRMW, ptr); + PUSH(AtomicRMW, value); + break; + } + case Expression::Id::AtomicWaitId: { + HASH(AtomicWait, expectedType); + PUSH(AtomicWait, ptr); + PUSH(AtomicWait, expected); + PUSH(AtomicWait, timeout); + break; + } + case Expression::Id::AtomicWakeId: { + PUSH(AtomicWake, ptr); + PUSH(AtomicWake, wakeCount); + break; + } + case Expression::Id::ConstId: { + HASH(Const, value.type); + HASH64(Const, value.getBits()); + break; + } + case Expression::Id::UnaryId: { + HASH(Unary, op); + PUSH(Unary, value); + break; + } + case Expression::Id::BinaryId: { + HASH(Binary, op); + PUSH(Binary, left); + PUSH(Binary, right); + break; + } + case Expression::Id::SelectId: { + PUSH(Select, ifTrue); + PUSH(Select, ifFalse); + PUSH(Select, condition); + break; + } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } + case Expression::Id::ReturnId: { + PUSH(Return, value); + break; + } + case Expression::Id::HostId: { + HASH(Host, op); + HASH_NAME(Host, nameOperand); + HASH(Host, operands.size()); + for (Index i = 0; i < curr->cast<Host>()->operands.size(); i++) { + PUSH(Host, operands[i]); + } + break; + } + case Expression::Id::NopId: { + break; + } + case Expression::Id::UnreachableId: { + break; + } + default: WASM_UNREACHABLE(); + } + #undef HASH + #undef PUSH + } + return digest; +} +} // namespace wasm diff --git a/src/ir/ExpressionManipulator.cpp b/src/ir/ExpressionManipulator.cpp new file mode 100644 index 000000000..aa2a10388 --- /dev/null +++ b/src/ir/ExpressionManipulator.cpp @@ -0,0 +1,179 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/utils.h" +#include "support/hash.h" + +namespace wasm { + +namespace ExpressionManipulator { + +Expression* flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { + struct Copier : public Visitor<Copier, Expression*> { + Module& wasm; + CustomCopier custom; + + Builder builder; + + Copier(Module& wasm, CustomCopier custom) : wasm(wasm), custom(custom), builder(wasm) {} + + Expression* copy(Expression* curr) { + if (!curr) return nullptr; + auto* ret = custom(curr); + if (ret) return ret; + return Visitor<Copier, Expression*>::visit(curr); + } + + Expression* visitBlock(Block *curr) { + auto* ret = builder.makeBlock(); + for (Index i = 0; i < curr->list.size(); i++) { + ret->list.push_back(copy(curr->list[i])); + } + ret->name = curr->name; + ret->finalize(curr->type); + return ret; + } + Expression* visitIf(If *curr) { + return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); + } + Expression* visitLoop(Loop *curr) { + return builder.makeLoop(curr->name, copy(curr->body)); + } + Expression* visitBreak(Break *curr) { + return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition)); + } + Expression* visitSwitch(Switch *curr) { + return builder.makeSwitch(curr->targets, curr->default_, copy(curr->condition), copy(curr->value)); + } + Expression* visitCall(Call *curr) { + auto* ret = builder.makeCall(curr->target, {}, curr->type); + for (Index i = 0; i < curr->operands.size(); i++) { + ret->operands.push_back(copy(curr->operands[i])); + } + return ret; + } + Expression* visitCallImport(CallImport *curr) { + auto* ret = builder.makeCallImport(curr->target, {}, curr->type); + for (Index i = 0; i < curr->operands.size(); i++) { + ret->operands.push_back(copy(curr->operands[i])); + } + return ret; + } + Expression* visitCallIndirect(CallIndirect *curr) { + auto* ret = builder.makeCallIndirect(curr->fullType, copy(curr->target), {}, curr->type); + for (Index i = 0; i < curr->operands.size(); i++) { + ret->operands.push_back(copy(curr->operands[i])); + } + return ret; + } + Expression* visitGetLocal(GetLocal *curr) { + return builder.makeGetLocal(curr->index, curr->type); + } + Expression* visitSetLocal(SetLocal *curr) { + if (curr->isTee()) { + return builder.makeTeeLocal(curr->index, copy(curr->value)); + } else { + return builder.makeSetLocal(curr->index, copy(curr->value)); + } + } + Expression* visitGetGlobal(GetGlobal *curr) { + return builder.makeGetGlobal(curr->name, curr->type); + } + Expression* visitSetGlobal(SetGlobal *curr) { + return builder.makeSetGlobal(curr->name, copy(curr->value)); + } + Expression* visitLoad(Load *curr) { + if (curr->isAtomic) { + return builder.makeAtomicLoad(curr->bytes, curr->offset, + copy(curr->ptr), curr->type); + } + return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type); + } + Expression* visitStore(Store *curr) { + if (curr->isAtomic) { + return builder.makeAtomicStore(curr->bytes, curr->offset, copy(curr->ptr), copy(curr->value), curr->valueType); + } + return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType); + } + Expression* visitAtomicRMW(AtomicRMW* curr) { + return builder.makeAtomicRMW(curr->op, curr->bytes, curr->offset, + copy(curr->ptr), copy(curr->value), curr->type); + } + Expression* visitAtomicCmpxchg(AtomicCmpxchg* curr) { + return builder.makeAtomicCmpxchg(curr->bytes, curr->offset, + copy(curr->ptr), copy(curr->expected), 
copy(curr->replacement), + curr->type); + } + Expression* visitAtomicWait(AtomicWait* curr) { + return builder.makeAtomicWait(copy(curr->ptr), copy(curr->expected), copy(curr->timeout), curr->expectedType); + } + Expression* visitAtomicWake(AtomicWake* curr) { + return builder.makeAtomicWake(copy(curr->ptr), copy(curr->wakeCount)); + } + Expression* visitConst(Const *curr) { + return builder.makeConst(curr->value); + } + Expression* visitUnary(Unary *curr) { + return builder.makeUnary(curr->op, copy(curr->value)); + } + Expression* visitBinary(Binary *curr) { + return builder.makeBinary(curr->op, copy(curr->left), copy(curr->right)); + } + Expression* visitSelect(Select *curr) { + return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); + } + Expression* visitDrop(Drop *curr) { + return builder.makeDrop(copy(curr->value)); + } + Expression* visitReturn(Return *curr) { + return builder.makeReturn(copy(curr->value)); + } + Expression* visitHost(Host *curr) { + assert(curr->operands.size() == 0); + return builder.makeHost(curr->op, curr->nameOperand, {}); + } + Expression* visitNop(Nop *curr) { + return builder.makeNop(); + } + Expression* visitUnreachable(Unreachable *curr) { + return builder.makeUnreachable(); + } + }; + + Copier copier(wasm, custom); + return copier.copy(original); +} + + +// Splice an item into the middle of a block's list +void spliceIntoBlock(Block* block, Index index, Expression* add) { + auto& list = block->list; + if (index == list.size()) { + list.push_back(add); // simple append + } else { + // we need to make room + list.push_back(nullptr); + for (Index i = list.size() - 1; i > index; i--) { + list[i] = list[i - 1]; + } + list[index] = add; + } + block->finalize(block->type); +} + +} // namespace ExpressionManipulator + +} // namespace wasm diff --git a/src/ir/LocalGraph.cpp b/src/ir/LocalGraph.cpp new file mode 100644 index 000000000..cee187c6d --- /dev/null +++ b/src/ir/LocalGraph.cpp @@ -0,0 +1,273 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <iterator> + +#include <wasm-builder.h> +#include <wasm-printing.h> +#include <ir/find_all.h> +#include <ir/local-graph.h> + +namespace wasm { + +LocalGraph::LocalGraph(Function* func, Module* module) { + walkFunctionInModule(func, module); + +#ifdef LOCAL_GRAPH_DEBUG + std::cout << "LocalGraph::dump\n"; + for (auto& pair : getSetses) { + auto* get = pair.first; + auto& sets = pair.second; + std::cout << "GET\n" << get << " is influenced by\n"; + for (auto* set : sets) { + std::cout << set << '\n'; + } + } +#endif +} + +void LocalGraph::computeInfluences() { + for (auto& pair : locations) { + auto* curr = pair.first; + if (auto* set = curr->dynCast<SetLocal>()) { + FindAll<GetLocal> findAll(set->value); + for (auto* get : findAll.list) { + getInfluences[get].insert(set); + } + } else { + auto* get = curr->cast<GetLocal>(); + for (auto* set : getSetses[get]) { + setInfluences[set].insert(get); + } + } + } +} + +void LocalGraph::doWalkFunction(Function* func) { + numLocals = func->getNumLocals(); + if (numLocals == 0) return; // nothing to do + // We begin with each param being assigned from the incoming value, and the zero-init for the locals, + // so the initial state is the identity permutation + currMapping.resize(numLocals); + for (auto& set : currMapping) { + set = { nullptr }; + } + PostWalker<LocalGraph>::walk(func->body); +} + +// control flow + +void LocalGraph::visitBlock(Block* curr) { + if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { + auto& infos = breakMappings[curr->name]; + infos.emplace_back(std::move(currMapping)); + currMapping = std::move(merge(infos)); + breakMappings.erase(curr->name); + } +} + +void LocalGraph::finishIf() { + // that's it for this if, merge + std::vector<Mapping> breaks; + breaks.emplace_back(std::move(currMapping)); + breaks.emplace_back(std::move(mappingStack.back())); + mappingStack.pop_back(); + currMapping = std::move(merge(breaks)); +} + +void LocalGraph::afterIfCondition(LocalGraph* self, Expression** currp) { + self->mappingStack.push_back(self->currMapping); +} +void LocalGraph::afterIfTrue(LocalGraph* self, Expression** currp) { + auto* curr = (*currp)->cast<If>(); + if (curr->ifFalse) { + auto afterCondition = std::move(self->mappingStack.back()); + self->mappingStack.back() = std::move(self->currMapping); + self->currMapping = std::move(afterCondition); + } else { + self->finishIf(); + } +} +void LocalGraph::afterIfFalse(LocalGraph* self, Expression** currp) { + self->finishIf(); +} +void LocalGraph::beforeLoop(LocalGraph* self, Expression** currp) { + // save the state before entering the loop, for calculation later of the merge at the loop top + self->mappingStack.push_back(self->currMapping); + self->loopGetStack.push_back({}); +} +void LocalGraph::visitLoop(Loop* curr) { + if (curr->name.is() && breakMappings.find(curr->name) != breakMappings.end()) { + auto& infos = breakMappings[curr->name]; + infos.emplace_back(std::move(mappingStack.back())); + auto before = infos.back(); + auto& merged = merge(infos); + // every local we created a phi for requires us to update get_local operations in + // the loop - the branch back has means that gets in the loop have potentially + // more sets reaching them. 
+ // we can detect this as follows: if a get of oldIndex has the same sets + // as the sets at the entrance to the loop, then it is affected by the loop + // header sets, and we can add to there sets that looped back + auto linkLoopTop = [&](Index i, Sets& getSets) { + auto& beforeSets = before[i]; + if (getSets.size() < beforeSets.size()) { + // the get trivially has fewer sets, so it overrode the loop entry sets + return; + } + std::vector<SetLocal*> intersection; + std::set_intersection(beforeSets.begin(), beforeSets.end(), + getSets.begin(), getSets.end(), + std::back_inserter(intersection)); + if (intersection.size() < beforeSets.size()) { + // the get has not the same sets as in the loop entry + return; + } + // the get has the entry sets, so add any new ones + for (auto* set : merged[i]) { + getSets.insert(set); + } + }; + auto& gets = loopGetStack.back(); + for (auto* get : gets) { + linkLoopTop(get->index, getSetses[get]); + } + // and the same for the loop fallthrough: any local that still has the + // entry sets should also have the loop-back sets as well + for (Index i = 0; i < numLocals; i++) { + linkLoopTop(i, currMapping[i]); + } + // finally, breaks still in flight must be updated too + for (auto& iter : breakMappings) { + auto name = iter.first; + if (name == curr->name) continue; // skip our own (which is still in use) + auto& mappings = iter.second; + for (auto& mapping : mappings) { + for (Index i = 0; i < numLocals; i++) { + linkLoopTop(i, mapping[i]); + } + } + } + // now that we are done with using the mappings, erase our own + breakMappings.erase(curr->name); + } + mappingStack.pop_back(); + loopGetStack.pop_back(); +} +void LocalGraph::visitBreak(Break* curr) { + if (curr->condition) { + breakMappings[curr->name].emplace_back(currMapping); + } else { + breakMappings[curr->name].emplace_back(std::move(currMapping)); + setUnreachable(currMapping); + } +} +void LocalGraph::visitSwitch(Switch* curr) { + std::set<Name> all; + for (auto target : curr->targets) { + all.insert(target); + } + all.insert(curr->default_); + for (auto target : all) { + breakMappings[target].emplace_back(currMapping); + } + setUnreachable(currMapping); +} +void LocalGraph::visitReturn(Return *curr) { + setUnreachable(currMapping); +} +void LocalGraph::visitUnreachable(Unreachable *curr) { + setUnreachable(currMapping); +} + +// local usage + +void LocalGraph::visitGetLocal(GetLocal* curr) { + assert(currMapping.size() == numLocals); + assert(curr->index < numLocals); + for (auto& loopGets : loopGetStack) { + loopGets.push_back(curr); + } + // current sets are our sets + getSetses[curr] = currMapping[curr->index]; + locations[curr] = getCurrentPointer(); +} +void LocalGraph::visitSetLocal(SetLocal* curr) { + assert(currMapping.size() == numLocals); + assert(curr->index < numLocals); + // current sets are just this set + currMapping[curr->index] = { curr }; // TODO optimize? 
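      // (note on the data structure, for orientation: a Mapping takes each local
      // index to the set of SetLocal*s that may reach the current point; the
      // nullptr entry installed in doWalkFunction stands for the incoming
      // parameter value or the implicit zero-init, so a set containing nullptr
      // may still observe the function-entry value)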
+ locations[curr] = getCurrentPointer(); +} + +// traversal + +void LocalGraph::scan(LocalGraph* self, Expression** currp) { + if (auto* iff = (*currp)->dynCast<If>()) { + // if needs special handling + if (iff->ifFalse) { + self->pushTask(LocalGraph::afterIfFalse, currp); + self->pushTask(LocalGraph::scan, &iff->ifFalse); + } + self->pushTask(LocalGraph::afterIfTrue, currp); + self->pushTask(LocalGraph::scan, &iff->ifTrue); + self->pushTask(LocalGraph::afterIfCondition, currp); + self->pushTask(LocalGraph::scan, &iff->condition); + } else { + PostWalker<LocalGraph>::scan(self, currp); + } + + // loops need pre-order visiting too + if ((*currp)->is<Loop>()) { + self->pushTask(LocalGraph::beforeLoop, currp); + } +} + +// helpers + +void LocalGraph::setUnreachable(Mapping& mapping) { + mapping.resize(numLocals); // may have been emptied by a move + mapping[0].clear(); +} + +bool LocalGraph::isUnreachable(Mapping& mapping) { + // we must have some set for each index, if only the zero init, so empty means we emptied it for unreachable code + return mapping[0].empty(); +} + +// merges a bunch of infos into one. +// if we need phis, writes them into the provided vector. the caller should +// ensure those are placed in the right location +LocalGraph::Mapping& LocalGraph::merge(std::vector<Mapping>& mappings) { + assert(mappings.size() > 0); + auto& out = mappings[0]; + if (mappings.size() == 1) { + return out; + } + // merge into the first + for (Index j = 1; j < mappings.size(); j++) { + auto& other = mappings[j]; + for (Index i = 0; i < numLocals; i++) { + auto& outSets = out[i]; + for (auto* set : other[i]) { + outSets.insert(set); + } + } + } + return out; +} + +} // namespace wasm + diff --git a/src/ir/bits.h b/src/ir/bits.h new file mode 100644 index 000000000..4196b74c1 --- /dev/null +++ b/src/ir/bits.h @@ -0,0 +1,107 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_bits_h +#define wasm_ir_bits_h + +#include "support/bits.h" +#include "wasm-builder.h" +#include "ir/literal-utils.h" + +namespace wasm { + +struct Bits { + // get a mask to keep only the low # of bits + static int32_t lowBitMask(int32_t bits) { + uint32_t ret = -1; + if (bits >= 32) return ret; + return ret >> (32 - bits); + } + + // checks if the input is a mask of lower bits, i.e., all 1s up to some high bit, and all zeros + // from there. returns the number of masked bits, or 0 if this is not such a mask + static uint32_t getMaskedBits(uint32_t mask) { + if (mask == uint32_t(-1)) return 32; // all the bits + if (mask == 0) return 0; // trivially not a mask + // otherwise, see if adding one turns this into a 1-bit thing, 00011111 + 1 => 00100000 + if (PopCount(mask + 1) != 1) return 0; + // this is indeed a mask + return 32 - CountLeadingZeroes(mask); + } + + // gets the number of effective shifts a shift operation does. In + // wasm, only 5 bits matter for 32-bit shifts, and 6 for 64. 
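A few hand-checked examples of the mask helpers above, plus the shift helper that follows (plain asserts, not from the patch):

  #include <cassert>
  // lowBitMask keeps exactly the requested low bits
  assert(Bits::lowBitMask(8) == 0xff);
  // 0xff + 1 == 0x100 has a single bit set, so 0xff masks 8 bits
  assert(Bits::getMaskedBits(0x000000ffu) == 8);
  assert(Bits::getMaskedBits(0xffffffffu) == 32); // all the bits
  assert(Bits::getMaskedBits(0x00000050u) == 0);  // 0b1010000 is not a low-bit mask
  // for getEffectiveShifts, defined next: only 5 bits matter at i32
  assert(Bits::getEffectiveShifts(33, i32) == 1); // 33 & 31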
+ static Index getEffectiveShifts(Index amount, WasmType type) { + if (type == i32) { + return amount & 31; + } else if (type == i64) { + return amount & 63; + } + WASM_UNREACHABLE(); + } + + static Index getEffectiveShifts(Expression* expr) { + auto* amount = expr->cast<Const>(); + if (amount->type == i32) { + return getEffectiveShifts(amount->value.geti32(), i32); + } else if (amount->type == i64) { + return getEffectiveShifts(amount->value.geti64(), i64); + } + WASM_UNREACHABLE(); + } + + static Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) { + if (value->type == i32) { + if (bytes == 1 || bytes == 2) { + auto shifts = bytes == 1 ? 24 : 16; + Builder builder(wasm); + return builder.makeBinary( + ShrSInt32, + builder.makeBinary( + ShlInt32, + value, + LiteralUtils::makeFromInt32(shifts, i32, wasm) + ), + LiteralUtils::makeFromInt32(shifts, i32, wasm) + ); + } + assert(bytes == 4); + return value; // nothing to do + } else { + assert(value->type == i64); + if (bytes == 1 || bytes == 2 || bytes == 4) { + auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32); + Builder builder(wasm); + return builder.makeBinary( + ShrSInt64, + builder.makeBinary( + ShlInt64, + value, + LiteralUtils::makeFromInt32(shifts, i64, wasm) + ), + LiteralUtils::makeFromInt32(shifts, i64, wasm) + ); + } + assert(bytes == 8); + return value; // nothing to do + } + } +}; + +} // namespace wasm + +#endif // wasm_ir_bits_h + diff --git a/src/ir/block-utils.h b/src/ir/block-utils.h new file mode 100644 index 000000000..f7c68aa39 --- /dev/null +++ b/src/ir/block-utils.h @@ -0,0 +1,67 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_block_h +#define wasm_ir_block_h + +#include "literal.h" +#include "wasm.h" +#include "ir/branch-utils.h" +#include "ir/effects.h" + +namespace wasm { + +namespace BlockUtils { + // if a block has just one element, it can often be replaced + // with that content + template<typename T> + inline Expression* simplifyToContents(Block* block, T* parent, bool allowTypeChange = false) { + auto& list = block->list; + if (list.size() == 1 && !BranchUtils::BranchSeeker::hasNamed(list[0], block->name)) { + // just one element. 
try to replace the block + auto* singleton = list[0]; + auto sideEffects = EffectAnalyzer(parent->getPassOptions(), singleton).hasSideEffects(); + if (!sideEffects && !isConcreteWasmType(singleton->type)) { + // no side effects, and singleton is not returning a value, so we can throw away + // the block and its contents, basically + return Builder(*parent->getModule()).replaceWithIdenticalType(block); + } else if (block->type == singleton->type || allowTypeChange) { + return singleton; + } else { + // (side effects +) type change, must be block with declared value but inside is unreachable + // (if both concrete, must match, and since no name on block, we can't be + // branched to, so if singleton is unreachable, so is the block) + assert(isConcreteWasmType(block->type) && singleton->type == unreachable); + // we could replace with unreachable, but would need to update all + // the parent's types + } + } else if (list.size() == 0) { + ExpressionManipulator::nop(block); + } + return block; + } + + // similar, but when we allow the type to change while doing so + template<typename T> + inline Expression* simplifyToContentsWithPossibleTypeChange(Block* block, T* parent) { + return simplifyToContents(block, parent, true); + } +}; + +} // namespace wasm + +#endif // wasm_ir_block_h + diff --git a/src/ir/branch-utils.h b/src/ir/branch-utils.h new file mode 100644 index 000000000..26e8e7c87 --- /dev/null +++ b/src/ir/branch-utils.h @@ -0,0 +1,183 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_branch_h +#define wasm_ir_branch_h + +#include "wasm.h" +#include "wasm-traversal.h" + +namespace wasm { + +namespace BranchUtils { + +// Some branches are obviously not actually reachable (e.g. 
(br $out (unreachable)))
+
+inline bool isBranchReachable(Break* br) {
+  return !(br->value && br->value->type == unreachable) &&
+         !(br->condition && br->condition->type == unreachable);
+}
+
+inline bool isBranchReachable(Switch* sw) {
+  return !(sw->value && sw->value->type == unreachable) &&
+         sw->condition->type != unreachable;
+}
+
+inline bool isBranchReachable(Expression* expr) {
+  if (auto* br = expr->dynCast<Break>()) {
+    return isBranchReachable(br);
+  } else if (auto* sw = expr->dynCast<Switch>()) {
+    return isBranchReachable(sw);
+  }
+  WASM_UNREACHABLE();
+}
+
+// returns the set of targets to which we branch that are
+// outside of a node
+inline std::set<Name> getExitingBranches(Expression* ast) {
+  struct Scanner : public PostWalker<Scanner> {
+    std::set<Name> targets;
+
+    void visitBreak(Break* curr) {
+      targets.insert(curr->name);
+    }
+    void visitSwitch(Switch* curr) {
+      for (auto target : curr->targets) {
+        targets.insert(target);
+      }
+      targets.insert(curr->default_);
+    }
+    void visitBlock(Block* curr) {
+      if (curr->name.is()) {
+        targets.erase(curr->name);
+      }
+    }
+    void visitLoop(Loop* curr) {
+      if (curr->name.is()) {
+        targets.erase(curr->name);
+      }
+    }
+  };
+  Scanner scanner;
+  scanner.walk(ast);
+  // anything not erased is a branch out
+  return scanner.targets;
+}
+
+// returns the set of all branch targets in a node
+
+inline std::set<Name> getBranchTargets(Expression* ast) {
+  struct Scanner : public PostWalker<Scanner> {
+    std::set<Name> targets;
+
+    void visitBlock(Block* curr) {
+      if (curr->name.is()) {
+        targets.insert(curr->name);
+      }
+    }
+    void visitLoop(Loop* curr) {
+      if (curr->name.is()) {
+        targets.insert(curr->name);
+      }
+    }
+  };
+  Scanner scanner;
+  scanner.walk(ast);
+  return scanner.targets;
+}
+
+// Finds if there are branches targeting a name. Note that since names are
+// unique in our IR, we just need to look for the name, and do not need
+// to analyze scoping.
+// By default we consider all branches, so any place there is a branch that
+// names the target. You can unset 'named' to only note branches that appear
+// reachable (i.e., are not obviously unreachable).
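For instance, via the static helpers declared at the end of the struct below (a sketch; block, body, and name are placeholders):

  // does anything still branch to this block's label?
  if (!BranchUtils::BranchSeeker::hasNamed(block, block->name)) {
    // the label is unused and could be cleared
  }

  // how many branches that are not obviously unreachable target `name`?
  Index n = BranchUtils::BranchSeeker::countReachable(body, name);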
+struct BranchSeeker : public PostWalker<BranchSeeker> { + Name target; + bool named = true; + + Index found; + WasmType valueType; + + BranchSeeker(Name target) : target(target), found(0) {} + + void noteFound(Expression* value) { + found++; + if (found == 1) valueType = unreachable; + if (!value) valueType = none; + else if (value->type != unreachable) valueType = value->type; + } + + void visitBreak(Break *curr) { + if (!named) { + // ignore an unreachable break + if (curr->condition && curr->condition->type == unreachable) return; + if (curr->value && curr->value->type == unreachable) return; + } + // check the break + if (curr->name == target) noteFound(curr->value); + } + + void visitSwitch(Switch *curr) { + if (!named) { + // ignore an unreachable switch + if (curr->condition->type == unreachable) return; + if (curr->value && curr->value->type == unreachable) return; + } + // check the switch + for (auto name : curr->targets) { + if (name == target) noteFound(curr->value); + } + if (curr->default_ == target) noteFound(curr->value); + } + + static bool hasReachable(Expression* tree, Name target) { + if (!target.is()) return false; + BranchSeeker seeker(target); + seeker.named = false; + seeker.walk(tree); + return seeker.found > 0; + } + + static Index countReachable(Expression* tree, Name target) { + if (!target.is()) return 0; + BranchSeeker seeker(target); + seeker.named = false; + seeker.walk(tree); + return seeker.found; + } + + static bool hasNamed(Expression* tree, Name target) { + if (!target.is()) return false; + BranchSeeker seeker(target); + seeker.walk(tree); + return seeker.found > 0; + } + + static Index countNamed(Expression* tree, Name target) { + if (!target.is()) return 0; + BranchSeeker seeker(target); + seeker.walk(tree); + return seeker.found; + } +}; + +} // namespace BranchUtils + +} // namespace wasm + +#endif // wasm_ir_branch_h + diff --git a/src/ir/cost.h b/src/ir/cost.h new file mode 100644 index 000000000..9a97574f4 --- /dev/null +++ b/src/ir/cost.h @@ -0,0 +1,255 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_cost_h +#define wasm_ir_cost_h + +namespace wasm { + +// Measure the execution cost of an AST. Very handwave-ey + +struct CostAnalyzer : public Visitor<CostAnalyzer, Index> { + CostAnalyzer(Expression *ast) { + assert(ast); + cost = visit(ast); + } + + Index cost; + + Index maybeVisit(Expression* curr) { + return curr ? 
visit(curr) : 0; + } + + Index visitBlock(Block *curr) { + Index ret = 0; + for (auto* child : curr->list) ret += visit(child); + return ret; + } + Index visitIf(If *curr) { + return 1 + visit(curr->condition) + std::max(visit(curr->ifTrue), maybeVisit(curr->ifFalse)); + } + Index visitLoop(Loop *curr) { + return 5 * visit(curr->body); + } + Index visitBreak(Break *curr) { + return 1 + maybeVisit(curr->value) + maybeVisit(curr->condition); + } + Index visitSwitch(Switch *curr) { + return 2 + visit(curr->condition) + maybeVisit(curr->value); + } + Index visitCall(Call *curr) { + Index ret = 4; + for (auto* child : curr->operands) ret += visit(child); + return ret; + } + Index visitCallImport(CallImport *curr) { + Index ret = 15; + for (auto* child : curr->operands) ret += visit(child); + return ret; + } + Index visitCallIndirect(CallIndirect *curr) { + Index ret = 6 + visit(curr->target); + for (auto* child : curr->operands) ret += visit(child); + return ret; + } + Index visitGetLocal(GetLocal *curr) { + return 0; + } + Index visitSetLocal(SetLocal *curr) { + return 1; + } + Index visitGetGlobal(GetGlobal *curr) { + return 1; + } + Index visitSetGlobal(SetGlobal *curr) { + return 2; + } + Index visitLoad(Load *curr) { + return 1 + visit(curr->ptr) + 10 * curr->isAtomic; + } + Index visitStore(Store *curr) { + return 2 + visit(curr->ptr) + visit(curr->value) + 10 * curr->isAtomic; + } + Index visitAtomicRMW(AtomicRMW *curr) { + return 100; + } + Index visitAtomicCmpxchg(AtomicCmpxchg* curr) { + return 100; + } + Index visitConst(Const *curr) { + return 1; + } + Index visitUnary(Unary *curr) { + Index ret = 0; + switch (curr->op) { + case ClzInt32: + case CtzInt32: + case PopcntInt32: + case NegFloat32: + case AbsFloat32: + case CeilFloat32: + case FloorFloat32: + case TruncFloat32: + case NearestFloat32: + case ClzInt64: + case CtzInt64: + case PopcntInt64: + case NegFloat64: + case AbsFloat64: + case CeilFloat64: + case FloorFloat64: + case TruncFloat64: + case NearestFloat64: + case EqZInt32: + case EqZInt64: + case ExtendSInt32: + case ExtendUInt32: + case WrapInt64: + case PromoteFloat32: + case DemoteFloat64: + case TruncSFloat32ToInt32: + case TruncUFloat32ToInt32: + case TruncSFloat64ToInt32: + case TruncUFloat64ToInt32: + case ReinterpretFloat32: + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt64: + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt64: + case ReinterpretFloat64: + case ReinterpretInt32: + case ConvertSInt32ToFloat32: + case ConvertUInt32ToFloat32: + case ConvertSInt64ToFloat32: + case ConvertUInt64ToFloat32: + case ReinterpretInt64: + case ConvertSInt32ToFloat64: + case ConvertUInt32ToFloat64: + case ConvertSInt64ToFloat64: + case ConvertUInt64ToFloat64: ret = 1; break; + case SqrtFloat32: + case SqrtFloat64: ret = 2; break; + default: WASM_UNREACHABLE(); + } + return ret + visit(curr->value); + } + Index visitBinary(Binary *curr) { + Index ret = 0; + switch (curr->op) { + case AddInt32: ret = 1; break; + case SubInt32: ret = 1; break; + case MulInt32: ret = 2; break; + case DivSInt32: ret = 3; break; + case DivUInt32: ret = 3; break; + case RemSInt32: ret = 3; break; + case RemUInt32: ret = 3; break; + case AndInt32: ret = 1; break; + case OrInt32: ret = 1; break; + case XorInt32: ret = 1; break; + case ShlInt32: ret = 1; break; + case ShrUInt32: ret = 1; break; + case ShrSInt32: ret = 1; break; + case RotLInt32: ret = 1; break; + case RotRInt32: ret = 1; break; + case AddInt64: ret = 1; break; + case SubInt64: ret = 1; break; + case MulInt64: ret = 2; 
break; + case DivSInt64: ret = 3; break; + case DivUInt64: ret = 3; break; + case RemSInt64: ret = 3; break; + case RemUInt64: ret = 3; break; + case AndInt64: ret = 1; break; + case OrInt64: ret = 1; break; + case XorInt64: ret = 1; break; + case ShlInt64: ret = 1; break; + case ShrUInt64: ret = 1; break; + case ShrSInt64: ret = 1; break; + case RotLInt64: ret = 1; break; + case RotRInt64: ret = 1; break; + case AddFloat32: ret = 1; break; + case SubFloat32: ret = 1; break; + case MulFloat32: ret = 2; break; + case DivFloat32: ret = 3; break; + case CopySignFloat32: ret = 1; break; + case MinFloat32: ret = 1; break; + case MaxFloat32: ret = 1; break; + case AddFloat64: ret = 1; break; + case SubFloat64: ret = 1; break; + case MulFloat64: ret = 2; break; + case DivFloat64: ret = 3; break; + case CopySignFloat64: ret = 1; break; + case MinFloat64: ret = 1; break; + case MaxFloat64: ret = 1; break; + case LtUInt32: ret = 1; break; + case LtSInt32: ret = 1; break; + case LeUInt32: ret = 1; break; + case LeSInt32: ret = 1; break; + case GtUInt32: ret = 1; break; + case GtSInt32: ret = 1; break; + case GeUInt32: ret = 1; break; + case GeSInt32: ret = 1; break; + case LtUInt64: ret = 1; break; + case LtSInt64: ret = 1; break; + case LeUInt64: ret = 1; break; + case LeSInt64: ret = 1; break; + case GtUInt64: ret = 1; break; + case GtSInt64: ret = 1; break; + case GeUInt64: ret = 1; break; + case GeSInt64: ret = 1; break; + case LtFloat32: ret = 1; break; + case GtFloat32: ret = 1; break; + case LeFloat32: ret = 1; break; + case GeFloat32: ret = 1; break; + case LtFloat64: ret = 1; break; + case GtFloat64: ret = 1; break; + case LeFloat64: ret = 1; break; + case GeFloat64: ret = 1; break; + case EqInt32: ret = 1; break; + case NeInt32: ret = 1; break; + case EqInt64: ret = 1; break; + case NeInt64: ret = 1; break; + case EqFloat32: ret = 1; break; + case NeFloat32: ret = 1; break; + case EqFloat64: ret = 1; break; + case NeFloat64: ret = 1; break; + default: WASM_UNREACHABLE(); + } + return ret + visit(curr->left) + visit(curr->right); + } + Index visitSelect(Select *curr) { + return 2 + visit(curr->condition) + visit(curr->ifTrue) + visit(curr->ifFalse); + } + Index visitDrop(Drop *curr) { + return visit(curr->value); + } + Index visitReturn(Return *curr) { + return maybeVisit(curr->value); + } + Index visitHost(Host *curr) { + return 100; + } + Index visitNop(Nop *curr) { + return 0; + } + Index visitUnreachable(Unreachable *curr) { + return 0; + } +}; + +} // namespace wasm + +#endif // wasm_ir_cost_h + diff --git a/src/ir/count.h b/src/ir/count.h new file mode 100644 index 000000000..1fef3a870 --- /dev/null +++ b/src/ir/count.h @@ -0,0 +1,50 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef wasm_ir_count_h +#define wasm_ir_count_h + +namespace wasm { + +struct GetLocalCounter : public PostWalker<GetLocalCounter> { + std::vector<Index> num; + + GetLocalCounter() {} + GetLocalCounter(Function* func) { + analyze(func, func->body); + } + GetLocalCounter(Function* func, Expression* ast) { + analyze(func, ast); + } + + void analyze(Function* func) { + analyze(func, func->body); + } + void analyze(Function* func, Expression* ast) { + num.resize(func->getNumLocals()); + std::fill(num.begin(), num.end(), 0); + walk(ast); + } + + void visitGetLocal(GetLocal *curr) { + num[curr->index]++; + } +}; + +} // namespace wasm + +#endif // wasm_ir_count_h + diff --git a/src/ir/effects.h b/src/ir/effects.h new file mode 100644 index 000000000..98911d451 --- /dev/null +++ b/src/ir/effects.h @@ -0,0 +1,278 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_effects_h +#define wasm_ir_effects_h + +namespace wasm { + +// Look for side effects, including control flow +// TODO: optimize + +struct EffectAnalyzer : public PostWalker<EffectAnalyzer> { + EffectAnalyzer(PassOptions& passOptions, Expression *ast = nullptr) { + ignoreImplicitTraps = passOptions.ignoreImplicitTraps; + debugInfo = passOptions.debugInfo; + if (ast) analyze(ast); + } + + bool ignoreImplicitTraps; + bool debugInfo; + + void analyze(Expression *ast) { + breakNames.clear(); + walk(ast); + // if we are left with breaks, they are external + if (breakNames.size() > 0) branches = true; + } + + bool branches = false; // branches out of this expression, returns, infinite loops, etc + bool calls = false; + std::set<Index> localsRead; + std::set<Index> localsWritten; + std::set<Name> globalsRead; + std::set<Name> globalsWritten; + bool readsMemory = false; + bool writesMemory = false; + bool implicitTrap = false; // a load or div/rem, which may trap. we ignore trap + // differences, so it is ok to reorder these, but we can't + // remove them, as they count as side effects, and we + // can't move them in a way that would cause other noticeable + // (global) side effects + bool isAtomic = false; // An atomic load/store/RMW/Cmpxchg or an operator that + // has a defined ordering wrt atomics (e.g. 
grow_memory) + + bool accessesLocal() { return localsRead.size() + localsWritten.size() > 0; } + bool accessesGlobal() { return globalsRead.size() + globalsWritten.size() > 0; } + bool accessesMemory() { return calls || readsMemory || writesMemory; } + bool hasGlobalSideEffects() { return calls || globalsWritten.size() > 0 || writesMemory || isAtomic; } + bool hasSideEffects() { return hasGlobalSideEffects() || localsWritten.size() > 0 || branches || implicitTrap; } + bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory || accessesGlobal() || implicitTrap || isAtomic; } + + // checks if these effects would invalidate another set (e.g., if we write, we invalidate someone that reads, they can't be moved past us) + bool invalidates(EffectAnalyzer& other) { + if (branches || other.branches + || ((writesMemory || calls) && other.accessesMemory()) + || (accessesMemory() && (other.writesMemory || other.calls))) { + return true; + } + // All atomics are sequentially consistent for now, and ordered wrt other + // memory references. + if ((isAtomic && other.accessesMemory()) || + (other.isAtomic && accessesMemory())) { + return true; + } + for (auto local : localsWritten) { + if (other.localsWritten.count(local) || other.localsRead.count(local)) { + return true; + } + } + for (auto local : localsRead) { + if (other.localsWritten.count(local)) return true; + } + if ((accessesGlobal() && other.calls) || (other.accessesGlobal() && calls)) { + return true; + } + for (auto global : globalsWritten) { + if (other.globalsWritten.count(global) || other.globalsRead.count(global)) { + return true; + } + } + for (auto global : globalsRead) { + if (other.globalsWritten.count(global)) return true; + } + // we are ok to reorder implicit traps, but not conditionalize them + if ((implicitTrap && other.branches) || (other.implicitTrap && branches)) { + return true; + } + // we can't reorder an implicit trap in a way that alters global state + if ((implicitTrap && other.hasGlobalSideEffects()) || (other.implicitTrap && hasGlobalSideEffects())) { + return true; + } + return false; + } + + void mergeIn(EffectAnalyzer& other) { + branches = branches || other.branches; + calls = calls || other.calls; + readsMemory = readsMemory || other.readsMemory; + writesMemory = writesMemory || other.writesMemory; + for (auto i : other.localsRead) localsRead.insert(i); + for (auto i : other.localsWritten) localsWritten.insert(i); + for (auto i : other.globalsRead) globalsRead.insert(i); + for (auto i : other.globalsWritten) globalsWritten.insert(i); + } + + // the checks above happen after the node's children were processed, in the order of execution + // we must also check for control flow that happens before the children, i.e., loops + bool checkPre(Expression* curr) { + if (curr->is<Loop>()) { + branches = true; + return true; + } + return false; + } + + bool checkPost(Expression* curr) { + visit(curr); + if (curr->is<Loop>()) { + branches = true; + } + return hasAnything(); + } + + std::set<Name> breakNames; + + void visitBreak(Break *curr) { + breakNames.insert(curr->name); + } + void visitSwitch(Switch *curr) { + for (auto name : curr->targets) { + breakNames.insert(name); + } + breakNames.insert(curr->default_); + } + void visitBlock(Block* curr) { + if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks + } + void visitLoop(Loop* curr) { + if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks + // if the loop is unreachable, then 
there is branching control flow: + // (1) if the body is unreachable because of a (return), uncaught (br) etc., then we + // already noted branching, so it is ok to mark it again (if we have *caught* + // (br)s, then they did not lead to the loop body being unreachable). + // (same logic applies to blocks) + // (2) if the loop is unreachable because it only has branches up to the loop + // top, but no way to get out, then it is an infinite loop, and we consider + // that a branching side effect (note how the same logic does not apply to + // blocks). + if (curr->type == unreachable) { + branches = true; + } + } + + void visitCall(Call *curr) { calls = true; } + void visitCallImport(CallImport *curr) { + calls = true; + if (debugInfo) { + // debugInfo call imports must be preserved very strongly, do not + // move code around them + branches = true; // ! + } + } + void visitCallIndirect(CallIndirect *curr) { calls = true; } + void visitGetLocal(GetLocal *curr) { + localsRead.insert(curr->index); + } + void visitSetLocal(SetLocal *curr) { + localsWritten.insert(curr->index); + } + void visitGetGlobal(GetGlobal *curr) { + globalsRead.insert(curr->name); + } + void visitSetGlobal(SetGlobal *curr) { + globalsWritten.insert(curr->name); + } + void visitLoad(Load *curr) { + readsMemory = true; + isAtomic |= curr->isAtomic; + if (!ignoreImplicitTraps) implicitTrap = true; + } + void visitStore(Store *curr) { + writesMemory = true; + isAtomic |= curr->isAtomic; + if (!ignoreImplicitTraps) implicitTrap = true; + } + void visitAtomicRMW(AtomicRMW* curr) { + readsMemory = true; + writesMemory = true; + isAtomic = true; + if (!ignoreImplicitTraps) implicitTrap = true; + } + void visitAtomicCmpxchg(AtomicCmpxchg* curr) { + readsMemory = true; + writesMemory = true; + isAtomic = true; + if (!ignoreImplicitTraps) implicitTrap = true; + } + void visitAtomicWait(AtomicWait* curr) { + readsMemory = true; + // AtomicWait doesn't strictly write memory, but it does modify the waiters + // list associated with the specified address, which we can think of as a + // write. + writesMemory = true; + isAtomic = true; + if (!ignoreImplicitTraps) implicitTrap = true; + } + void visitAtomicWake(AtomicWake* curr) { + // AtomicWake doesn't strictly write memory, but it does modify the waiters + // list associated with the specified address, which we can think of as a + // write. + readsMemory = true; + writesMemory = true; + isAtomic = true; + if (!ignoreImplicitTraps) implicitTrap = true; + }; + void visitUnary(Unary *curr) { + if (!ignoreImplicitTraps) { + switch (curr->op) { + case TruncSFloat32ToInt32: + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt32: + case TruncUFloat32ToInt64: + case TruncSFloat64ToInt32: + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt32: + case TruncUFloat64ToInt64: { + implicitTrap = true; + break; + } + default: {} + } + } + } + void visitBinary(Binary *curr) { + if (!ignoreImplicitTraps) { + switch (curr->op) { + case DivSInt32: + case DivUInt32: + case RemSInt32: + case RemUInt32: + case DivSInt64: + case DivUInt64: + case RemSInt64: + case RemUInt64: { + implicitTrap = true; + break; + } + default: {} + } + } + } + void visitReturn(Return *curr) { branches = true; } + void visitHost(Host *curr) { + calls = true; + // grow_memory modifies the set of valid addresses, and thus can be modeled as modifying memory + writesMemory = true; + // Atomics are also sequentially consistent with grow_memory. 
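+    // (note that Host covers all host operations; we model them all
+    //  conservatively here, with the same effects.)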
+ isAtomic = true; + } + void visitUnreachable(Unreachable *curr) { branches = true; } +}; + +} // namespace wasm + +#endif // wasm_ir_effects_h diff --git a/src/ir/find_all.h b/src/ir/find_all.h new file mode 100644 index 000000000..83c751666 --- /dev/null +++ b/src/ir/find_all.h @@ -0,0 +1,48 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_find_all_h +#define wasm_ir_find_all_h + +#include <wasm-traversal.h> + +namespace wasm { + +// Find all instances of a certain node type + +template<typename T> +struct FindAll { + std::vector<T*> list; + + FindAll(Expression* ast) { + struct Finder : public PostWalker<Finder, UnifiedExpressionVisitor<Finder>> { + std::vector<T*>* list; + void visitExpression(Expression* curr) { + if (curr->is<T>()) { + (*list).push_back(curr->cast<T>()); + } + } + }; + Finder finder; + finder.list = &list; + finder.walk(ast); + } +}; + +} // namespace wasm + +#endif // wasm_ir_find_all_h + diff --git a/src/ir/global-utils.h b/src/ir/global-utils.h new file mode 100644 index 000000000..bcf0dae72 --- /dev/null +++ b/src/ir/global-utils.h @@ -0,0 +1,55 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_global_h +#define wasm_ir_global_h + +#include <algorithm> +#include <vector> + +#include "literal.h" +#include "wasm.h" + +namespace wasm { + +namespace GlobalUtils { + // find a global initialized to the value of an import, or null if no such global + inline Global* getGlobalInitializedToImport(Module& wasm, Name module, Name base) { + // find the import + Name imported; + for (auto& import : wasm.imports) { + if (import->module == module && import->base == base) { + imported = import->name; + break; + } + } + if (imported.isNull()) return nullptr; + // find a global inited to it + for (auto& global : wasm.globals) { + if (auto* init = global->init->dynCast<GetGlobal>()) { + if (init->name == imported) { + return global.get(); + } + } + } + return nullptr; + } +}; + +} // namespace wasm + +#endif // wasm_ir_global_h + diff --git a/src/ir/hashed.h b/src/ir/hashed.h new file mode 100644 index 000000000..dc4012455 --- /dev/null +++ b/src/ir/hashed.h @@ -0,0 +1,59 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _wasm_ir_hashed_h
+#define _wasm_ir_hashed_h
+
+#include "support/hash.h"
+#include "wasm.h"
+#include "ir/utils.h"
+
+namespace wasm {
+
+// An expression with a cached hash value
+struct HashedExpression {
+  Expression* expr;
+  size_t hash;
+
+  HashedExpression(Expression* expr) : expr(expr), hash(0) {
+    // a null expression keeps a deterministic zero hash
+    if (expr) {
+      hash = ExpressionAnalyzer::hash(expr);
+    }
+  }
+
+  HashedExpression(const HashedExpression& other) : expr(other.expr), hash(other.hash) {}
+};
+
+struct ExpressionHasher {
+  size_t operator()(const HashedExpression value) const {
+    return value.hash;
+  }
+};
+
+struct ExpressionComparer {
+  bool operator()(const HashedExpression a, const HashedExpression b) const {
+    if (a.hash != b.hash) return false;
+    return ExpressionAnalyzer::equal(a.expr, b.expr);
+  }
+};
+
+template<typename T>
+class HashedExpressionMap : public std::unordered_map<HashedExpression, T, ExpressionHasher, ExpressionComparer> {
+};
+
+} // namespace wasm
+
+#endif // _wasm_ir_hashed_h
+
diff --git a/src/ir/import-utils.h b/src/ir/import-utils.h
new file mode 100644
index 000000000..f3f01c266
--- /dev/null
+++ b/src/ir/import-utils.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ir_import_h
+#define wasm_ir_import_h
+
+#include "literal.h"
+#include "wasm.h"
+
+namespace wasm {
+
+namespace ImportUtils {
+  // find an import by the module.base that is being imported.
+  // return the internal name
+  inline Import* getImport(Module& wasm, Name module, Name base) {
+    for (auto& import : wasm.imports) {
+      if (import->module == module && import->base == base) {
+        return import.get();
+      }
+    }
+    return nullptr;
+  }
+};
+
+} // namespace wasm
+
+#endif // wasm_ir_import_h
+
diff --git a/src/ir/label-utils.h b/src/ir/label-utils.h
new file mode 100644
index 000000000..f4fb77697
--- /dev/null
+++ b/src/ir/label-utils.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef wasm_ir_label_h +#define wasm_ir_label_h + +#include "wasm.h" +#include "wasm-traversal.h" + +namespace wasm { + +namespace LabelUtils { + +// Handles branch/loop labels in a function; makes it easy to add new +// ones without duplicates +class LabelManager : public PostWalker<LabelManager> { +public: + LabelManager(Function* func) { + walkFunction(func); + } + + Name getUnique(std::string prefix) { + while (1) { + auto curr = Name(prefix + std::to_string(counter++)); + if (labels.find(curr) == labels.end()) { + labels.insert(curr); + return curr; + } + } + } + + void visitBlock(Block* curr) { + labels.insert(curr->name); + } + void visitLoop(Loop* curr) { + labels.insert(curr->name); + } + +private: + std::set<Name> labels; + size_t counter = 0; +}; + +} // namespace LabelUtils + +} // namespace wasm + +#endif // wasm_ir_label_h + diff --git a/src/ir/literal-utils.h b/src/ir/literal-utils.h new file mode 100644 index 000000000..a702c52eb --- /dev/null +++ b/src/ir/literal-utils.h @@ -0,0 +1,56 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_literal_utils_h +#define wasm_ir_literal_utils_h + +#include "wasm.h" + +namespace wasm { + +namespace LiteralUtils { + +inline Literal makeLiteralFromInt32(int32_t x, WasmType type) { + switch (type) { + case i32: return Literal(int32_t(x)); break; + case i64: return Literal(int64_t(x)); break; + case f32: return Literal(float(x)); break; + case f64: return Literal(double(x)); break; + default: WASM_UNREACHABLE(); + } +} + +inline Literal makeLiteralZero(WasmType type) { + return makeLiteralFromInt32(0, type); +} + +inline Expression* makeFromInt32(int32_t x, WasmType type, Module& wasm) { + auto* ret = wasm.allocator.alloc<Const>(); + ret->value = makeLiteralFromInt32(x, type); + ret->type = type; + return ret; +} + +inline Expression* makeZero(WasmType type, Module& wasm) { + return makeFromInt32(0, type, wasm); +} + +} // namespace LiteralUtils + +} // namespace wasm + +#endif // wasm_ir_literal_utils_h + diff --git a/src/ir/load-utils.h b/src/ir/load-utils.h new file mode 100644 index 000000000..edc7eb90f --- /dev/null +++ b/src/ir/load-utils.h @@ -0,0 +1,40 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef wasm_ir_load_h
+#define wasm_ir_load_h
+
+#include "wasm.h"
+
+namespace wasm {
+
+namespace LoadUtils {
+
+// checks if the sign of a load matters, which is when an integer
+// load is of fewer bytes than the size of the type (so the remaining
+// bits must be filled in, in either a signed or an unsigned manner)
+inline bool isSignRelevant(Load* load) {
+  auto type = load->type;
+  if (type == unreachable) return false;
+  return !isWasmTypeFloat(type) && load->bytes < getWasmTypeSize(type);
+}
+
+} // namespace LoadUtils
+
+} // namespace wasm
+
+#endif // wasm_ir_load_h
+
diff --git a/src/ir/local-graph.h b/src/ir/local-graph.h
new file mode 100644
index 000000000..4c4c1ee0a
--- /dev/null
+++ b/src/ir/local-graph.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ir_local_graph_h
+#define wasm_ir_local_graph_h
+
+namespace wasm {
+
+//
+// Finds the connections between get_locals and set_locals, creating
+// a graph of those ties. This is useful for "ssa-style" optimization,
+// in which you want to know exactly which sets are relevant for a get,
+// so it is as if each get has just one set, logically speaking
+// (see the SSA pass for actually creating new local indexes based
+// on this).
+//
+// TODO: the algorithm here is pretty simple, but also pretty slow;
+//       we should optimize it. e.g. we rely on set_interaction
+//       here, and worse we only use it to compute the size...
+struct LocalGraph : public PostWalker<LocalGraph> {
+  // main API
+
+  // the constructor computes getSetses, the sets affecting each get
+  LocalGraph(Function* func, Module* module);
+
+  // the set_locals relevant for an index or a get.
+  typedef std::set<SetLocal*> Sets;
+
+  // externally useful information
+  std::map<GetLocal*, Sets> getSetses; // the sets affecting each get. a nullptr set means the initial
+                                       // value (0 for a var, the received value for a param)
+  std::map<Expression*, Expression**> locations; // where each get and set is (for easy replacing)
+
+  // optional computation: compute the influence graphs between sets and gets
+  // (useful for algorithms that propagate changes)
+
+  std::unordered_map<GetLocal*, std::unordered_set<SetLocal*>> getInfluences; // for each get, the sets whose values are influenced by that get
+  std::unordered_map<SetLocal*, std::unordered_set<GetLocal*>> setInfluences; // for each set, the gets whose values are influenced by that set
+
+  void computeInfluences();
+
+private:
+  // we map local index => the set_locals for that index.
+  // a nullptr set means there is a virtual set, from a param
+  // initial value or the zero init initial value.
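+  // for example, if currMapping[i] == { set1, set2 }, then a get of local i
+  // at this point in the walk may observe the value written by either set.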
+ typedef std::vector<Sets> Mapping; + + // internal state + Index numLocals; + Mapping currMapping; + std::vector<Mapping> mappingStack; // used in ifs, loops + std::map<Name, std::vector<Mapping>> breakMappings; // break target => infos that reach it + std::vector<std::vector<GetLocal*>> loopGetStack; // stack of loops, all the gets in each, so we can update them for back branches + +public: + void doWalkFunction(Function* func); + + // control flow + + void visitBlock(Block* curr); + + void finishIf(); + + static void afterIfCondition(LocalGraph* self, Expression** currp); + static void afterIfTrue(LocalGraph* self, Expression** currp); + static void afterIfFalse(LocalGraph* self, Expression** currp); + static void beforeLoop(LocalGraph* self, Expression** currp); + void visitLoop(Loop* curr); + void visitBreak(Break* curr); + void visitSwitch(Switch* curr); + void visitReturn(Return *curr); + void visitUnreachable(Unreachable *curr); + + // local usage + + void visitGetLocal(GetLocal* curr); + void visitSetLocal(SetLocal* curr); + + // traversal + + static void scan(LocalGraph* self, Expression** currp); + + // helpers + + void setUnreachable(Mapping& mapping); + + bool isUnreachable(Mapping& mapping); + + // merges a bunch of infos into one. + // if we need phis, writes them into the provided vector. the caller should + // ensure those are placed in the right location + Mapping& merge(std::vector<Mapping>& mappings); +}; + +} // namespace wasm + +#endif // wasm_ir_local_graph_h + diff --git a/src/ir/localize.h b/src/ir/localize.h new file mode 100644 index 000000000..c910d9f9b --- /dev/null +++ b/src/ir/localize.h @@ -0,0 +1,47 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_localizer_h +#define wasm_ir_localizer_h + +#include <wasm-builder.h> + +namespace wasm { + +// Make an expression available in a local. If already in one, just +// use that local, otherwise use a new local + +struct Localizer { + Index index; + Expression* expr; + + Localizer(Expression* input, Function* func, Module* wasm) { + expr = input; + if (auto* get = expr->dynCast<GetLocal>()) { + index = get->index; + } else if (auto* set = expr->dynCast<SetLocal>()) { + index = set->index; + } else { + index = Builder::addVar(func, expr->type); + expr = Builder(*wasm).makeTeeLocal(index, expr); + } + } +}; + +} // namespace wasm + +#endif // wasm_ir_localizer_h + diff --git a/src/ir/manipulation.h b/src/ir/manipulation.h new file mode 100644 index 000000000..57188ad68 --- /dev/null +++ b/src/ir/manipulation.h @@ -0,0 +1,69 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ir_manipulation_h
+#define wasm_ir_manipulation_h
+
+#include "wasm.h"
+
+namespace wasm {
+
+namespace ExpressionManipulator {
+  // Re-use a node's memory. This helps avoid allocation when optimizing.
+  template<typename InputType, typename OutputType>
+  inline OutputType* convert(InputType *input) {
+    static_assert(sizeof(OutputType) <= sizeof(InputType),
+                  "Can only convert to a smaller size Expression node");
+    input->~InputType(); // arena-allocated, so the destructor is trivial, but call it to avoid UB
+    OutputType* output = (OutputType*)(input);
+    new (output) OutputType;
+    return output;
+  }
+
+  // Convenience method for nop, which is a common conversion
+  template<typename InputType>
+  inline Nop* nop(InputType* target) {
+    return convert<InputType, Nop>(target);
+  }
+
+  // Convert a node that allocates
+  template<typename InputType, typename OutputType>
+  inline OutputType* convert(InputType *input, MixedArena& allocator) {
+    static_assert(sizeof(OutputType) <= sizeof(InputType),
+                  "Can only convert to a smaller size Expression node");
+    input->~InputType(); // arena-allocated, so the destructor is trivial, but call it to avoid UB
+    OutputType* output = (OutputType*)(input);
+    new (output) OutputType(allocator);
+    return output;
+  }
+
+  using CustomCopier = std::function<Expression*(Expression*)>;
+  Expression* flexibleCopy(Expression* original, Module& wasm, CustomCopier custom);
+
+  inline Expression* copy(Expression* original, Module& wasm) {
+    auto copy = [](Expression* curr) {
+      return nullptr;
+    };
+    return flexibleCopy(original, wasm, copy);
+  }
+
+  // Splice an item into the middle of a block's list
+  void spliceIntoBlock(Block* block, Index index, Expression* add);
+}
+
+} // wasm
+
+#endif // wasm_ir_manipulation_h
+
diff --git a/src/ir/memory-utils.h b/src/ir/memory-utils.h
new file mode 100644
index 000000000..920583f7d
--- /dev/null
+++ b/src/ir/memory-utils.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ir_memory_h
+#define wasm_ir_memory_h
+
+#include <algorithm>
+#include <vector>
+
+#include "literal.h"
+#include "wasm.h"
+
+namespace wasm {
+
+namespace MemoryUtils {
+  // flattens memory into a single data segment.
returns true if successful + inline bool flatten(Memory& memory) { + if (memory.segments.size() == 0) return true; + std::vector<char> data; + for (auto& segment : memory.segments) { + auto* offset = segment.offset->dynCast<Const>(); + if (!offset) return false; + } + for (auto& segment : memory.segments) { + auto* offset = segment.offset->dynCast<Const>(); + auto start = offset->value.getInteger(); + auto end = start + segment.data.size(); + if (end > data.size()) { + data.resize(end); + } + std::copy(segment.data.begin(), segment.data.end(), data.begin() + start); + } + memory.segments.resize(1); + memory.segments[0].offset->cast<Const>()->value = Literal(int32_t(0)); + memory.segments[0].data.swap(data); + return true; + } +}; + +} // namespace wasm + +#endif // wasm_ir_memory_h + diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h new file mode 100644 index 000000000..0c828f83a --- /dev/null +++ b/src/ir/module-utils.h @@ -0,0 +1,59 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_module_h +#define wasm_ir_module_h + +#include "wasm.h" + +namespace wasm { + +namespace ModuleUtils { + +// Computes the indexes in a wasm binary, i.e., with function imports +// and function implementations sharing a single index space, etc. +struct BinaryIndexes { + std::unordered_map<Name, Index> functionIndexes; + std::unordered_map<Name, Index> globalIndexes; + + BinaryIndexes(Module& wasm) { + for (Index i = 0; i < wasm.imports.size(); i++) { + auto& import = wasm.imports[i]; + if (import->kind == ExternalKind::Function) { + auto index = functionIndexes.size(); + functionIndexes[import->name] = index; + } else if (import->kind == ExternalKind::Global) { + auto index = globalIndexes.size(); + globalIndexes[import->name] = index; + } + } + for (Index i = 0; i < wasm.functions.size(); i++) { + auto index = functionIndexes.size(); + functionIndexes[wasm.functions[i]->name] = index; + } + for (Index i = 0; i < wasm.globals.size(); i++) { + auto index = globalIndexes.size(); + globalIndexes[wasm.globals[i]->name] = index; + } + } +}; + +} // namespace ModuleUtils + +} // namespace wasm + +#endif // wasm_ir_module_h + diff --git a/src/ir/properties.h b/src/ir/properties.h new file mode 100644 index 000000000..cf481218c --- /dev/null +++ b/src/ir/properties.h @@ -0,0 +1,141 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef wasm_ir_properties_h
+#define wasm_ir_properties_h
+
+#include "wasm.h"
+#include "ir/bits.h"
+
+namespace wasm {
+
+struct Properties {
+  static bool emitsBoolean(Expression* curr) {
+    if (auto* unary = curr->dynCast<Unary>()) {
+      return unary->isRelational();
+    } else if (auto* binary = curr->dynCast<Binary>()) {
+      return binary->isRelational();
+    }
+    return false;
+  }
+
+  static bool isSymmetric(Binary* binary) {
+    switch (binary->op) {
+      case AddInt32:
+      case MulInt32:
+      case AndInt32:
+      case OrInt32:
+      case XorInt32:
+      case EqInt32:
+      case NeInt32:
+
+      case AddInt64:
+      case MulInt64:
+      case AndInt64:
+      case OrInt64:
+      case XorInt64:
+      case EqInt64:
+      case NeInt64: return true;
+
+      default: return false;
+    }
+  }
+
+  // Check if an expression is a sign-extend, and if so, returns the value
+  // that is extended, otherwise nullptr
+  static Expression* getSignExtValue(Expression* curr) {
+    if (auto* outer = curr->dynCast<Binary>()) {
+      if (outer->op == ShrSInt32) {
+        if (auto* outerConst = outer->right->dynCast<Const>()) {
+          if (outerConst->value.geti32() != 0) {
+            if (auto* inner = outer->left->dynCast<Binary>()) {
+              if (inner->op == ShlInt32) {
+                if (auto* innerConst = inner->right->dynCast<Const>()) {
+                  if (outerConst->value == innerConst->value) {
+                    return inner->left;
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+    return nullptr;
+  }
+
+  // gets the size of the sign-extended value
+  static Index getSignExtBits(Expression* curr) {
+    return 32 - Bits::getEffectiveShifts(curr->cast<Binary>()->right);
+  }
+
+  // Check if an expression is almost a sign-extend: perhaps the inner shift
+  // is too large. We can split the shifts in that case, which is sometimes
+  // useful (e.g. if we can remove the signext)
+  static Expression* getAlmostSignExt(Expression* curr) {
+    if (auto* outer = curr->dynCast<Binary>()) {
+      if (outer->op == ShrSInt32) {
+        if (auto* outerConst = outer->right->dynCast<Const>()) {
+          if (outerConst->value.geti32() != 0) {
+            if (auto* inner = outer->left->dynCast<Binary>()) {
+              if (inner->op == ShlInt32) {
+                if (auto* innerConst = inner->right->dynCast<Const>()) {
+                  if (Bits::getEffectiveShifts(outerConst) <= Bits::getEffectiveShifts(innerConst)) {
+                    return inner->left;
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+    return nullptr;
+  }
+
+  // gets the size of the almost sign-extended value, as well as the
+  // extra shifts, if any
+  static Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) {
+    extraShifts = Bits::getEffectiveShifts(curr->cast<Binary>()->left->cast<Binary>()->right) -
+                  Bits::getEffectiveShifts(curr->cast<Binary>()->right);
+    return getSignExtBits(curr);
+  }
+
+  // Check if an expression is a zero-extend, and if so, returns the value
+  // that is extended, otherwise nullptr
+  static Expression* getZeroExtValue(Expression* curr) {
+    if (auto* binary = curr->dynCast<Binary>()) {
+      if (binary->op == AndInt32) {
+        if (auto* c = binary->right->dynCast<Const>()) {
+          if (Bits::getMaskedBits(c->value.geti32())) {
+            return binary->left; // the masked value is the one being extended
+          }
+        }
+      }
+    }
+    return nullptr;
+  }
+
+  // gets the size of the zero-extended value
+  static Index getZeroExtBits(Expression* curr) {
+    return Bits::getMaskedBits(curr->cast<Binary>()->right->cast<Const>()->value.geti32());
+  }
+};
+
+} // wasm
+
+#endif // wasm_ir_properties_h
+
diff --git a/src/ir/trapping.h b/src/ir/trapping.h
new file mode 100644
index 000000000..a3a87f8ef
--- /dev/null
+++ b/src/ir/trapping.h
@@ -0,0 +1,120 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_trapping_h +#define wasm_ir_trapping_h + +#include <exception> + +#include "pass.h" + +namespace wasm { + +enum class TrapMode { + Allow, + Clamp, + JS +}; + +inline void addTrapModePass(PassRunner& runner, TrapMode trapMode) { + if (trapMode == TrapMode::Clamp) { + runner.add("trap-mode-clamp"); + } else if (trapMode == TrapMode::JS) { + runner.add("trap-mode-js"); + } +} + +class TrappingFunctionContainer { +public: + TrappingFunctionContainer(TrapMode mode, Module &wasm, bool immediate = false) + : mode(mode), + wasm(wasm), + immediate(immediate) { } + + bool hasFunction(Name name) { + return functions.find(name) != functions.end(); + } + bool hasImport(Name name) { + return imports.find(name) != imports.end(); + } + + void addFunction(Function* function) { + functions[function->name] = function; + if (immediate) { + wasm.addFunction(function); + } + } + void addImport(Import* import) { + imports[import->name] = import; + if (immediate) { + wasm.addImport(import); + } + } + + void addToModule() { + if (!immediate) { + for (auto &pair : functions) { + wasm.addFunction(pair.second); + } + for (auto &pair : imports) { + wasm.addImport(pair.second); + } + } + functions.clear(); + imports.clear(); + } + + TrapMode getMode() { + return mode; + } + + Module& getModule() { + return wasm; + } + + std::map<Name, Function*>& getFunctions() { + return functions; + } + +private: + std::map<Name, Function*> functions; + std::map<Name, Import*> imports; + + TrapMode mode; + Module& wasm; + bool immediate; +}; + +Expression* makeTrappingBinary(Binary* curr, TrappingFunctionContainer &trappingFunctions); +Expression* makeTrappingUnary(Unary* curr, TrappingFunctionContainer &trappingFunctions); + +inline TrapMode trapModeFromString(std::string const& str) { + if (str == "allow") { + return TrapMode::Allow; + } else if (str == "clamp") { + return TrapMode::Clamp; + } else if (str == "js") { + return TrapMode::JS; + } else { + throw std::invalid_argument( + "Unsupported trap mode \"" + str + "\". " + "Valid modes are \"allow\", \"js\", and \"clamp\""); + } +} + +} // wasm + +#endif // wasm_ir_trapping_h diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h new file mode 100644 index 000000000..79b26aa43 --- /dev/null +++ b/src/ir/type-updating.h @@ -0,0 +1,286 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef wasm_ir_type_updating_h +#define wasm_ir_type_updating_h + +#include "wasm-traversal.h" + +namespace wasm { + +// a class that tracks type dependencies between nodes, letting you +// update types efficiently when removing and altering code. +// altering code can alter types in the following way: +// * removing a break can make a block unreachable, if nothing else +// reaches it +// * altering the type of a child to unreachable can make the parent +// unreachable +struct TypeUpdater : public ExpressionStackWalker<TypeUpdater, UnifiedExpressionVisitor<TypeUpdater>> { + // Part 1: Scanning + + // track names to their blocks, so that when we remove a break to + // a block, we know how to find it if we need to update it + struct BlockInfo { + Block* block = nullptr; + int numBreaks = 0; + }; + std::map<Name, BlockInfo> blockInfos; + + // track the parent of each node, as child type changes may lead to + // unreachability + std::map<Expression*, Expression*> parents; + + void visitExpression(Expression* curr) { + if (expressionStack.size() > 1) { + parents[curr] = expressionStack[expressionStack.size() - 2]; + } else { + parents[curr] = nullptr; // this is the top level + } + // discover block/break relationships + if (auto* block = curr->dynCast<Block>()) { + if (block->name.is()) { + blockInfos[block->name].block = block; + } + } else if (auto* br = curr->dynCast<Break>()) { + // ensure info exists, discoverBreaks can then fill it + blockInfos[br->name]; + } else if (auto* sw = curr->dynCast<Switch>()) { + // ensure info exists, discoverBreaks can then fill it + for (auto target : sw->targets) { + blockInfos[target]; + } + blockInfos[sw->default_]; + } + // add a break to the info, for break and switch + discoverBreaks(curr, +1); + } + + // Part 2: Updating + + // Node replacements, additions, removals and type changes should be noted. An + // exception is nodes you know will never be looked at again. + + // note the replacement of one node with another. this should be called + // after performing the replacement. + // this does *not* look into the node by default. see noteReplacementWithRecursiveRemoval + // (we don't support recursive addition because in practice we do not create + // new trees in the passes that use this, they just move around children) + void noteReplacement(Expression* from, Expression* to, bool recursivelyRemove=false) { + auto parent = parents[from]; + if (recursivelyRemove) { + noteRecursiveRemoval(from); + } else { + noteRemoval(from); + } + // if we are replacing with a child, i.e. 
a node that was already present
+    // in the ast, then we just have a type and parent to update
+    if (parents.find(to) != parents.end()) {
+      parents[to] = parent;
+      if (from->type != to->type) {
+        propagateTypesUp(to);
+      }
+    } else {
+      noteAddition(to, parent, from);
+    }
+  }
+
+  void noteReplacementWithRecursiveRemoval(Expression* from, Expression* to) {
+    noteReplacement(from, to, true);
+  }
+
+  // note the removal of a node
+  void noteRemoval(Expression* curr) {
+    noteRemovalOrAddition(curr, nullptr);
+    parents.erase(curr);
+  }
+
+  // note the removal of a node and all its children
+  void noteRecursiveRemoval(Expression* curr) {
+    struct Recurser : public PostWalker<Recurser, UnifiedExpressionVisitor<Recurser>> {
+      TypeUpdater& parent;
+
+      Recurser(TypeUpdater& parent, Expression* root) : parent(parent) {
+        walk(root);
+      }
+
+      void visitExpression(Expression* curr) {
+        parent.noteRemoval(curr);
+      }
+    };
+
+    Recurser(*this, curr);
+  }
+
+  void noteAddition(Expression* curr, Expression* parent, Expression* previous = nullptr) {
+    assert(parents.find(curr) == parents.end()); // must not already exist
+    noteRemovalOrAddition(curr, parent);
+    // if we didn't replace with the exact same type, propagate types up
+    if (!(previous && previous->type == curr->type)) {
+      propagateTypesUp(curr);
+    }
+  }
+
+  // if parent is nullptr, this is a removal
+  void noteRemovalOrAddition(Expression* curr, Expression* parent) {
+    parents[curr] = parent;
+    discoverBreaks(curr, parent ? +1 : -1);
+  }
+
+  // adds (or removes) breaks depending on break/switch contents
+  void discoverBreaks(Expression* curr, int change) {
+    if (auto* br = curr->dynCast<Break>()) {
+      noteBreakChange(br->name, change, br->value);
+    } else if (auto* sw = curr->dynCast<Switch>()) {
+      applySwitchChanges(sw, change);
+    }
+  }
+
+  void applySwitchChanges(Switch* sw, int change) {
+    std::set<Name> seen;
+    for (auto target : sw->targets) {
+      if (seen.insert(target).second) {
+        noteBreakChange(target, change, sw->value);
+      }
+    }
+    if (seen.insert(sw->default_).second) {
+      noteBreakChange(sw->default_, change, sw->value);
+    }
+  }
+
+  // note a change in the number of breaks reaching a block
+  void noteBreakChange(Name name, int change, Expression* value) {
+    auto iter = blockInfos.find(name);
+    if (iter == blockInfos.end()) {
+      return; // we can ignore breaks to loops
+    }
+    auto& info = iter->second;
+    info.numBreaks += change;
+    assert(info.numBreaks >= 0);
+    auto* block = info.block;
+    if (block) { // if to a loop, can ignore
+      if (info.numBreaks == 0) {
+        // dropped to 0! the block may now be unreachable. that
+        // requires that it doesn't have a fallthrough
+        makeBlockUnreachableIfNoFallThrough(block);
+      } else if (change == 1 && info.numBreaks == 1) {
+        // bumped to 1! the block may now be reachable
+        if (block->type != unreachable) {
+          return; // was already reachable, had a fallthrough
+        }
+        changeTypeTo(block, value ? value->type : none);
+      }
+    }
+  }
+
+  // alters the type of a node to a new type.
+  // this propagates the type change through all the parents.
+  void changeTypeTo(Expression* curr, WasmType newType) {
+    if (curr->type == newType) return; // nothing to do
+    curr->type = newType;
+    propagateTypesUp(curr);
+  }
+
+  // given a node that has a new type, or is a new node, update
+  // all the parents accordingly. the existence of the node and
+  // any changes to it already occurred, this just updates the
+  // parents following that. i.e., nothing is done to the
+  // node we start on, it's done.
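+  // (e.g., if a block's final child became unreachable and no breaks target
+  //  the block, the block becomes unreachable, and perhaps its parents too.)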
+ // the one thing we need to do here is propagate unreachability, + // no other change is possible + void propagateTypesUp(Expression* curr) { + if (curr->type != unreachable) return; + while (1) { + auto* child = curr; + curr = parents[child]; + if (!curr) return; + // get ready to apply unreachability to this node + if (curr->type == unreachable) { + return; // already unreachable, stop here + } + // most nodes become unreachable if a child is unreachable, + // but exceptions exist + if (auto* block = curr->dynCast<Block>()) { + // if the block has a fallthrough, it can keep its type + if (isConcreteWasmType(block->list.back()->type)) { + return; // did not turn + } + // if the block has breaks, it can keep its type + if (!block->name.is() || blockInfos[block->name].numBreaks == 0) { + curr->type = unreachable; + } else { + return; // did not turn + } + } else if (auto* iff = curr->dynCast<If>()) { + // may not be unreachable if just one side is + iff->finalize(); + if (curr->type != unreachable) { + return; // did not turn + } + } else { + curr->type = unreachable; + } + } + } + + // efficiently update the type of a block, given the data we know. this + // can remove a concrete type and turn the block unreachable when it is + // unreachable, and it does this efficiently, without scanning the full + // contents + void maybeUpdateTypeToUnreachable(Block* curr) { + if (!isConcreteWasmType(curr->type)) { + return; // nothing concrete to change to unreachable + } + if (curr->name.is() && blockInfos[curr->name].numBreaks > 0) { + return; // has a break, not unreachable + } + // look for a fallthrough + makeBlockUnreachableIfNoFallThrough(curr); + } + + void makeBlockUnreachableIfNoFallThrough(Block* curr) { + if (curr->type == unreachable) { + return; // no change possible + } + if (!curr->list.empty() && + isConcreteWasmType(curr->list.back()->type)) { + return; // should keep type due to fallthrough, even if has an unreachable child + } + for (auto* child : curr->list) { + if (child->type == unreachable) { + // no fallthrough, and an unreachable, => this block is now unreachable + changeTypeTo(curr, unreachable); + return; + } + } + } + + // efficiently update the type of an if, given the data we know. this + // can remove a concrete type and turn the if unreachable when it is + // unreachable + void maybeUpdateTypeToUnreachable(If* curr) { + if (!isConcreteWasmType(curr->type)) { + return; // nothing concrete to change to unreachable + } + curr->finalize(); + if (curr->type == unreachable) { + propagateTypesUp(curr); + } + } +}; + +} // namespace wasm + +#endif // wasm_ir_type_updating_h diff --git a/src/ir/utils.h b/src/ir/utils.h new file mode 100644 index 000000000..786e04e45 --- /dev/null +++ b/src/ir/utils.h @@ -0,0 +1,360 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef wasm_ir_utils_h
+#define wasm_ir_utils_h
+
+#include "wasm.h"
+#include "wasm-traversal.h"
+#include "wasm-builder.h"
+#include "pass.h"
+#include "ir/branch-utils.h"
+
+namespace wasm {
+
+// Measure the size of an AST
+
+struct Measurer : public PostWalker<Measurer, UnifiedExpressionVisitor<Measurer>> {
+  Index size = 0;
+
+  void visitExpression(Expression* curr) {
+    size++;
+  }
+
+  static Index measure(Expression* tree) {
+    Measurer measurer;
+    measurer.walk(tree);
+    return measurer.size;
+  }
+};
+
+struct ExpressionAnalyzer {
+  // Given a stack of expressions, checks if the topmost is used as a result.
+  // For example, if the parent is a block and the node is before the last position,
+  // it is not used.
+  static bool isResultUsed(std::vector<Expression*> stack, Function* func);
+
+  // Checks if a value is dropped.
+  static bool isResultDropped(std::vector<Expression*> stack);
+
+  // Checks if a break is simple - no condition, no value, just a plain branch
+  static bool isSimple(Break* curr) {
+    return !curr->condition && !curr->value;
+  }
+
+  using ExprComparer = std::function<bool(Expression*, Expression*)>;
+  static bool flexibleEqual(Expression* left, Expression* right, ExprComparer comparer);
+
+  static bool equal(Expression* left, Expression* right) {
+    auto comparer = [](Expression* left, Expression* right) {
+      return false;
+    };
+    return flexibleEqual(left, right, comparer);
+  }
+
+  // hash an expression, ignoring superficial details like specific internal names
+  static uint32_t hash(Expression* curr);
+};
+
+// Re-finalizes all node types.
+// This removes "unnecessary" block/if/loop types, i.e., ones that are added
+// specifically, as in
+//  (block (result i32) (unreachable))
+// vs
+//  (block (unreachable))
+// This converts to the latter form.
+struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<ReFinalize>>> {
+  bool isFunctionParallel() override { return true; }
+
+  Pass* create() override { return new ReFinalize; }
+
+  ReFinalize() { name = "refinalize"; }
+
+  // block finalization is O(bad) if we do each block by itself, so do it in bulk,
+  // tracking break value types so we just do a linear pass
+
+  std::map<Name, WasmType> breakValues;
+
+  void visitBlock(Block *curr) {
+    if (curr->list.size() == 0) {
+      curr->type = none;
+      return;
+    }
+    // do this quickly, without any validation
+    auto old = curr->type;
+    // last element determines type
+    curr->type = curr->list.back()->type;
+    // if concrete, it doesn't matter if we have an unreachable child, and we
+    // don't need to look at breaks
+    if (isConcreteWasmType(curr->type)) return;
+    // otherwise, we have no final fallthrough element to determine the type,
+    // could be determined by breaks
+    if (curr->name.is()) {
+      auto iter = breakValues.find(curr->name);
+      if (iter != breakValues.end()) {
+        // there is a break to here
+        auto type = iter->second;
+        if (type == unreachable) {
+          // all we have are breaks with values of type unreachable, and no
+          // concrete fallthrough either.
we must have had an existing type, then + curr->type = old; + assert(isConcreteWasmType(curr->type)); + } else { + curr->type = type; + } + return; + } + } + if (curr->type == unreachable) return; + // type is none, but we might be unreachable + if (curr->type == none) { + for (auto* child : curr->list) { + if (child->type == unreachable) { + curr->type = unreachable; + break; + } + } + } + } + void visitIf(If *curr) { curr->finalize(); } + void visitLoop(Loop *curr) { curr->finalize(); } + void visitBreak(Break *curr) { + curr->finalize(); + updateBreakValueType(curr->name, getValueType(curr->value)); + } + void visitSwitch(Switch *curr) { + curr->finalize(); + auto valueType = getValueType(curr->value); + for (auto target : curr->targets) { + updateBreakValueType(target, valueType); + } + updateBreakValueType(curr->default_, valueType); + } + void visitCall(Call *curr) { curr->finalize(); } + void visitCallImport(CallImport *curr) { curr->finalize(); } + void visitCallIndirect(CallIndirect *curr) { curr->finalize(); } + void visitGetLocal(GetLocal *curr) { curr->finalize(); } + void visitSetLocal(SetLocal *curr) { curr->finalize(); } + void visitGetGlobal(GetGlobal *curr) { curr->finalize(); } + void visitSetGlobal(SetGlobal *curr) { curr->finalize(); } + void visitLoad(Load *curr) { curr->finalize(); } + void visitStore(Store *curr) { curr->finalize(); } + void visitAtomicRMW(AtomicRMW *curr) { curr->finalize(); } + void visitAtomicCmpxchg(AtomicCmpxchg *curr) { curr->finalize(); } + void visitAtomicWait(AtomicWait* curr) { curr->finalize(); } + void visitAtomicWake(AtomicWake* curr) { curr->finalize(); } + void visitConst(Const *curr) { curr->finalize(); } + void visitUnary(Unary *curr) { curr->finalize(); } + void visitBinary(Binary *curr) { curr->finalize(); } + void visitSelect(Select *curr) { curr->finalize(); } + void visitDrop(Drop *curr) { curr->finalize(); } + void visitReturn(Return *curr) { curr->finalize(); } + void visitHost(Host *curr) { curr->finalize(); } + void visitNop(Nop *curr) { curr->finalize(); } + void visitUnreachable(Unreachable *curr) { curr->finalize(); } + + void visitFunction(Function* curr) { + // we may have changed the body from unreachable to none, which might be bad + // if the function has a return value + if (curr->result != none && curr->body->type == none) { + Builder builder(*getModule()); + curr->body = builder.blockify(curr->body, builder.makeUnreachable()); + } + } + + void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); } + void visitImport(Import* curr) { WASM_UNREACHABLE(); } + void visitExport(Export* curr) { WASM_UNREACHABLE(); } + void visitGlobal(Global* curr) { WASM_UNREACHABLE(); } + void visitTable(Table* curr) { WASM_UNREACHABLE(); } + void visitMemory(Memory* curr) { WASM_UNREACHABLE(); } + void visitModule(Module* curr) { WASM_UNREACHABLE(); } + + WasmType getValueType(Expression* value) { + return value ? value->type : none; + } + + void updateBreakValueType(Name name, WasmType type) { + if (type != unreachable || breakValues.count(name) == 0) { + breakValues[name] = type; + } + } +}; + +// Re-finalize a single node. 
This is slow; if you want to refinalize
+// an entire ast, use ReFinalize
+struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> {
+  void visitBlock(Block *curr) { curr->finalize(); }
+  void visitIf(If *curr) { curr->finalize(); }
+  void visitLoop(Loop *curr) { curr->finalize(); }
+  void visitBreak(Break *curr) { curr->finalize(); }
+  void visitSwitch(Switch *curr) { curr->finalize(); }
+  void visitCall(Call *curr) { curr->finalize(); }
+  void visitCallImport(CallImport *curr) { curr->finalize(); }
+  void visitCallIndirect(CallIndirect *curr) { curr->finalize(); }
+  void visitGetLocal(GetLocal *curr) { curr->finalize(); }
+  void visitSetLocal(SetLocal *curr) { curr->finalize(); }
+  void visitGetGlobal(GetGlobal *curr) { curr->finalize(); }
+  void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
+  void visitLoad(Load *curr) { curr->finalize(); }
+  void visitStore(Store *curr) { curr->finalize(); }
+  void visitAtomicRMW(AtomicRMW* curr) { curr->finalize(); }
+  void visitAtomicCmpxchg(AtomicCmpxchg* curr) { curr->finalize(); }
+  void visitAtomicWait(AtomicWait* curr) { curr->finalize(); }
+  void visitAtomicWake(AtomicWake* curr) { curr->finalize(); }
+  void visitConst(Const *curr) { curr->finalize(); }
+  void visitUnary(Unary *curr) { curr->finalize(); }
+  void visitBinary(Binary *curr) { curr->finalize(); }
+  void visitSelect(Select *curr) { curr->finalize(); }
+  void visitDrop(Drop *curr) { curr->finalize(); }
+  void visitReturn(Return *curr) { curr->finalize(); }
+  void visitHost(Host *curr) { curr->finalize(); }
+  void visitNop(Nop *curr) { curr->finalize(); }
+  void visitUnreachable(Unreachable *curr) { curr->finalize(); }
+
+  void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); }
+  void visitImport(Import* curr) { WASM_UNREACHABLE(); }
+  void visitExport(Export* curr) { WASM_UNREACHABLE(); }
+  void visitGlobal(Global* curr) { WASM_UNREACHABLE(); }
+  void visitTable(Table* curr) { WASM_UNREACHABLE(); }
+  void visitMemory(Memory* curr) { WASM_UNREACHABLE(); }
+  void visitModule(Module* curr) { WASM_UNREACHABLE(); }
+
+  // given a stack of nested expressions, update them all from child to parent
+  static void updateStack(std::vector<Expression*>& expressionStack) {
+    for (int i = int(expressionStack.size()) - 1; i >= 0; i--) {
+      auto* curr = expressionStack[i];
+      ReFinalizeNode().visit(curr);
+    }
+  }
+};
+
+// Adds drop() operations where necessary. This lets you not worry about
+// adding drop when generating code.
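+// For example, a block child of concrete type whose value is never used, like
+// (i32.add x y) in statement position, becomes (drop (i32.add x y)).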
+// This also refinalizes before and after, as dropping can change types, and
+// depends on types being cleaned up - no unnecessary block/if/loop types (see
+// refinalize)
+// TODO: optimize that, interleave them
+struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> {
+  bool isFunctionParallel() override { return true; }
+
+  Pass* create() override { return new AutoDrop; }
+
+  AutoDrop() { name = "autodrop"; }
+
+  bool maybeDrop(Expression*& child) {
+    bool acted = false;
+    if (isConcreteWasmType(child->type)) {
+      expressionStack.push_back(child);
+      if (!ExpressionAnalyzer::isResultUsed(expressionStack, getFunction()) && !ExpressionAnalyzer::isResultDropped(expressionStack)) {
+        child = Builder(*getModule()).makeDrop(child);
+        acted = true;
+      }
+      expressionStack.pop_back();
+    }
+    return acted;
+  }
+
+  void reFinalize() {
+    ReFinalizeNode::updateStack(expressionStack);
+  }
+
+  void visitBlock(Block* curr) {
+    if (curr->list.size() == 0) return;
+    for (Index i = 0; i < curr->list.size() - 1; i++) {
+      auto* child = curr->list[i];
+      if (isConcreteWasmType(child->type)) {
+        curr->list[i] = Builder(*getModule()).makeDrop(child);
+      }
+    }
+    if (maybeDrop(curr->list.back())) {
+      reFinalize();
+      assert(curr->type == none || curr->type == unreachable);
+    }
+  }
+
+  void visitIf(If* curr) {
+    bool acted = false;
+    if (maybeDrop(curr->ifTrue)) acted = true;
+    if (curr->ifFalse) {
+      if (maybeDrop(curr->ifFalse)) acted = true;
+    }
+    if (acted) {
+      reFinalize();
+      assert(curr->type == none);
+    }
+  }
+
+  void doWalkFunction(Function* curr) {
+    ReFinalize().walkFunctionInModule(curr, getModule());
+    walk(curr->body);
+    if (curr->result == none && isConcreteWasmType(curr->body->type)) {
+      curr->body = Builder(*getModule()).makeDrop(curr->body);
+    }
+    ReFinalize().walkFunctionInModule(curr, getModule());
+  }
+};
+
+struct I64Utilities {
+  static Expression* recreateI64(Builder& builder, Expression* low, Expression* high) {
+    return builder.makeBinary(
+      OrInt64,
+      builder.makeUnary(ExtendUInt32, low),
+      builder.makeBinary(
+        ShlInt64,
+        builder.makeUnary(ExtendUInt32, high),
+        builder.makeConst(Literal(int64_t(32)))
+      )
+    );
+  }
+
+  static Expression* recreateI64(Builder& builder, Index low, Index high) {
+    return recreateI64(builder, builder.makeGetLocal(low, i32), builder.makeGetLocal(high, i32));
+  }
+
+  static Expression* getI64High(Builder& builder, Index index) {
+    return builder.makeUnary(
+      WrapInt64,
+      builder.makeBinary(
+        ShrUInt64,
+        builder.makeGetLocal(index, i64),
+        builder.makeConst(Literal(int64_t(32)))
+      )
+    );
+  }
+
+  static Expression* getI64Low(Builder& builder, Index index) {
+    return builder.makeUnary(
+      WrapInt64,
+      builder.makeGetLocal(index, i64)
+    );
+  }
+};
+
+} // namespace wasm
+
+#endif // wasm_ir_utils_h
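Editor's note: since the headers above are standalone utilities, a short usage sketch may help. The following example is not part of the patch; it is a minimal, hypothetical sketch that assumes the usual Binaryen include paths and uses only APIs shown in this diff. The helper names canReorder and countGets are invented for illustration.

    #include <algorithm>
    #include <cassert>
    #include <numeric>

    #include "pass.h"
    #include "wasm.h"
    #include "wasm-traversal.h"
    #include "ir/count.h"
    #include "ir/effects.h"
    #include "ir/find_all.h"

    using namespace wasm;

    // Two adjacent expressions may swap places only if neither's side effects
    // invalidate the other's; invalidates() already checks both directions.
    static bool canReorder(PassOptions& options, Expression* a, Expression* b) {
      EffectAnalyzer effectsA(options, a);
      EffectAnalyzer effectsB(options, b);
      return !effectsA.invalidates(effectsB);
    }

    // Count get_locals two ways: FindAll collects the nodes themselves, while
    // GetLocalCounter tallies uses per local index; the totals must agree.
    static void countGets(Function* func) {
      FindAll<GetLocal> gets(func->body);
      GetLocalCounter counter(func);
      assert(gets.list.size() ==
             std::accumulate(counter.num.begin(), counter.num.end(), size_t(0)));
    }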