Diffstat (limited to 'src/ir/utils.h')
-rw-r--r--  src/ir/utils.h  360
1 file changed, 360 insertions(+), 0 deletions(-)
diff --git a/src/ir/utils.h b/src/ir/utils.h
new file mode 100644
index 000000000..786e04e45
--- /dev/null
+++ b/src/ir/utils.h
@@ -0,0 +1,360 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ir_utils_h
+#define wasm_ir_utils_h
+
+#include "wasm.h"
+#include "wasm-traversal.h"
+#include "wasm-builder.h"
+#include "pass.h"
+#include "ir/branch-utils.h"
+
+namespace wasm {
+
+// Measure the size of an AST
+
+struct Measurer : public PostWalker<Measurer, UnifiedExpressionVisitor<Measurer>> {
+ Index size = 0;
+
+ void visitExpression(Expression* curr) {
+ size++;
+ }
+
+ static Index measure(Expression* tree) {
+ Measurer measurer;
+ measurer.walk(tree);
+ return measurer.size;
+ }
+};
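+
+// A minimal usage sketch (hypothetical 'expr' pointer to the tree to measure):
+//
+//   Expression* expr = ...;
+//   Index size = Measurer::measure(expr);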
+
+struct ExpressionAnalyzer {
+ // Given a stack of expressions, checks if the topmost is used as a result.
+ // For example, if the parent is a block and the node is before the last position,
+ // it is not used.
+ static bool isResultUsed(std::vector<Expression*> stack, Function* func);
+
+ // Checks if a value is dropped.
+ static bool isResultDropped(std::vector<Expression*> stack);
+
+ // Checks if a break is simple - no condition, no value, just plain branching.
+ static bool isSimple(Break* curr) {
+ return !curr->condition && !curr->value;
+ }
+
+ using ExprComparer = std::function<bool(Expression*, Expression*)>;
+ static bool flexibleEqual(Expression* left, Expression* right, ExprComparer comparer);
+
+ static bool equal(Expression* left, Expression* right) {
+ auto comparer = [](Expression* left, Expression* right) {
+ return false;
+ };
+ return flexibleEqual(left, right, comparer);
+ }
+
+ // Hash an expression, ignoring superficial details like specific internal names.
+ static uint32_t hash(Expression* curr);
+};
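+
+// Usage sketch for flexibleEqual (hypothetical expressions 'left' and 'right').
+// The comparer can declare extra pairs of nodes equal; always returning false,
+// as equal() above does, appears to mean pure structural equality:
+//
+//   bool same = ExpressionAnalyzer::flexibleEqual(left, right,
+//     [](Expression* a, Expression* b) { return false; });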
+
+// Re-finalizes all node types.
+// This removes "unnecessary" block/if/loop types, i.e., ones that are set
+// explicitly even though they could be omitted, as in
+//   (block (result i32) (unreachable))
+// vs
+//   (block (unreachable))
+// This converts the former into the latter form.
+struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<ReFinalize>>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new ReFinalize; }
+
+ ReFinalize() { name = "refinalize"; }
+
+ // Block finalization is quadratic if we finalize each block by itself, so we
+ // do it in bulk, tracking break value types so that a single linear pass
+ // suffices.
+
+ std::map<Name, WasmType> breakValues;
+
+ void visitBlock(Block *curr) {
+ if (curr->list.size() == 0) {
+ curr->type = none;
+ return;
+ }
+ // do this quickly, without any validation
+ auto old = curr->type;
+ // last element determines type
+ curr->type = curr->list.back()->type;
+ // if concrete, it doesn't matter if we have an unreachable child, and we
+ // don't need to look at breaks
+ if (isConcreteWasmType(curr->type)) return;
+ // otherwise, we have no final fallthrough element to determine the type,
+ // could be determined by breaks
+ if (curr->name.is()) {
+ auto iter = breakValues.find(curr->name);
+ if (iter != breakValues.end()) {
+ // there is a break to here
+ auto type = iter->second;
+ if (type == unreachable) {
+ // all we have are breaks with values of type unreachable, and no
+ // concrete fallthrough either. we must have had an existing type, then
+ curr->type = old;
+ assert(isConcreteWasmType(curr->type));
+ } else {
+ curr->type = type;
+ }
+ return;
+ }
+ }
+ if (curr->type == unreachable) return;
+ // type is none, but we might be unreachable
+ if (curr->type == none) {
+ for (auto* child : curr->list) {
+ if (child->type == unreachable) {
+ curr->type = unreachable;
+ break;
+ }
+ }
+ }
+ }
+ void visitIf(If *curr) { curr->finalize(); }
+ void visitLoop(Loop *curr) { curr->finalize(); }
+ void visitBreak(Break *curr) {
+ curr->finalize();
+ updateBreakValueType(curr->name, getValueType(curr->value));
+ }
+ void visitSwitch(Switch *curr) {
+ curr->finalize();
+ auto valueType = getValueType(curr->value);
+ for (auto target : curr->targets) {
+ updateBreakValueType(target, valueType);
+ }
+ updateBreakValueType(curr->default_, valueType);
+ }
+ void visitCall(Call *curr) { curr->finalize(); }
+ void visitCallImport(CallImport *curr) { curr->finalize(); }
+ void visitCallIndirect(CallIndirect *curr) { curr->finalize(); }
+ void visitGetLocal(GetLocal *curr) { curr->finalize(); }
+ void visitSetLocal(SetLocal *curr) { curr->finalize(); }
+ void visitGetGlobal(GetGlobal *curr) { curr->finalize(); }
+ void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
+ void visitLoad(Load *curr) { curr->finalize(); }
+ void visitStore(Store *curr) { curr->finalize(); }
+ void visitAtomicRMW(AtomicRMW *curr) { curr->finalize(); }
+ void visitAtomicCmpxchg(AtomicCmpxchg *curr) { curr->finalize(); }
+ void visitAtomicWait(AtomicWait* curr) { curr->finalize(); }
+ void visitAtomicWake(AtomicWake* curr) { curr->finalize(); }
+ void visitConst(Const *curr) { curr->finalize(); }
+ void visitUnary(Unary *curr) { curr->finalize(); }
+ void visitBinary(Binary *curr) { curr->finalize(); }
+ void visitSelect(Select *curr) { curr->finalize(); }
+ void visitDrop(Drop *curr) { curr->finalize(); }
+ void visitReturn(Return *curr) { curr->finalize(); }
+ void visitHost(Host *curr) { curr->finalize(); }
+ void visitNop(Nop *curr) { curr->finalize(); }
+ void visitUnreachable(Unreachable *curr) { curr->finalize(); }
+
+ void visitFunction(Function* curr) {
+ // we may have changed the body from unreachable to none, which might be bad
+ // if the function has a return value
+ if (curr->result != none && curr->body->type == none) {
+ Builder builder(*getModule());
+ curr->body = builder.blockify(curr->body, builder.makeUnreachable());
+ }
+ }
+
+ void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); }
+ void visitImport(Import* curr) { WASM_UNREACHABLE(); }
+ void visitExport(Export* curr) { WASM_UNREACHABLE(); }
+ void visitGlobal(Global* curr) { WASM_UNREACHABLE(); }
+ void visitTable(Table* curr) { WASM_UNREACHABLE(); }
+ void visitMemory(Memory* curr) { WASM_UNREACHABLE(); }
+ void visitModule(Module* curr) { WASM_UNREACHABLE(); }
+
+ WasmType getValueType(Expression* value) {
+ return value ? value->type : none;
+ }
+
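+ // Record the type of a break value sent to the given target. An unreachable
+ // value cannot pin down the target's type, so it is recorded only when we
+ // have seen nothing better.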
+ void updateBreakValueType(Name name, WasmType type) {
+ if (type != unreachable || breakValues.count(name) == 0) {
+ breakValues[name] = type;
+ }
+ }
+};
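+
+// Usage sketch: on a single function, ReFinalize can be run directly, as
+// doWalkFunction in AutoDrop below does (hypothetical 'func' and 'module'):
+//
+//   ReFinalize().walkFunctionInModule(func, &module);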
+
+// Re-finalizes a single node. This is slow; to refinalize an entire AST,
+// use ReFinalize above.
+struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> {
+ void visitBlock(Block *curr) { curr->finalize(); }
+ void visitIf(If *curr) { curr->finalize(); }
+ void visitLoop(Loop *curr) { curr->finalize(); }
+ void visitBreak(Break *curr) { curr->finalize(); }
+ void visitSwitch(Switch *curr) { curr->finalize(); }
+ void visitCall(Call *curr) { curr->finalize(); }
+ void visitCallImport(CallImport *curr) { curr->finalize(); }
+ void visitCallIndirect(CallIndirect *curr) { curr->finalize(); }
+ void visitGetLocal(GetLocal *curr) { curr->finalize(); }
+ void visitSetLocal(SetLocal *curr) { curr->finalize(); }
+ void visitGetGlobal(GetGlobal *curr) { curr->finalize(); }
+ void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
+ void visitLoad(Load *curr) { curr->finalize(); }
+ void visitStore(Store *curr) { curr->finalize(); }
+ void visitAtomicRMW(AtomicRMW* curr) { curr->finalize(); }
+ void visitAtomicCmpxchg(AtomicCmpxchg* curr) { curr->finalize(); }
+ void visitAtomicWait(AtomicWait* curr) { curr->finalize(); }
+ void visitAtomicWake(AtomicWake* curr) { curr->finalize(); }
+ void visitConst(Const *curr) { curr->finalize(); }
+ void visitUnary(Unary *curr) { curr->finalize(); }
+ void visitBinary(Binary *curr) { curr->finalize(); }
+ void visitSelect(Select *curr) { curr->finalize(); }
+ void visitDrop(Drop *curr) { curr->finalize(); }
+ void visitReturn(Return *curr) { curr->finalize(); }
+ void visitHost(Host *curr) { curr->finalize(); }
+ void visitNop(Nop *curr) { curr->finalize(); }
+ void visitUnreachable(Unreachable *curr) { curr->finalize(); }
+
+ void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); }
+ void visitImport(Import* curr) { WASM_UNREACHABLE(); }
+ void visitExport(Export* curr) { WASM_UNREACHABLE(); }
+ void visitGlobal(Global* curr) { WASM_UNREACHABLE(); }
+ void visitTable(Table* curr) { WASM_UNREACHABLE(); }
+ void visitMemory(Memory* curr) { WASM_UNREACHABLE(); }
+ void visitModule(Module* curr) { WASM_UNREACHABLE(); }
+
+ // Given a stack of nested expressions, updates them all, from child to parent.
+ static void updateStack(std::vector<Expression*>& expressionStack) {
+ for (int i = int(expressionStack.size()) - 1; i >= 0; i--) {
+ auto* curr = expressionStack[i];
+ ReFinalizeNode().visit(curr);
+ }
+ }
+};
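+
+// Usage sketch: after replacing a node in place, fix up the types of the node
+// and all its parents, innermost first (AutoDrop::reFinalize below does this):
+//
+//   ReFinalizeNode::updateStack(expressionStack);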
+
+// Adds drop() operations where necessary. This lets you generate code without
+// worrying about where drops must be added.
+// This also refinalizes before and after, as dropping can change types, and it
+// depends on types being cleaned up - no unnecessary block/if/loop types
+// (see ReFinalize).
+// TODO: optimize this by interleaving the refinalization with the dropping
+struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new AutoDrop; }
+
+ AutoDrop() { name = "autodrop"; }
+
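+ // If the child has a concrete type but its result is neither used nor already
+ // dropped, wraps it in a drop. Returns whether anything was changed.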
+ bool maybeDrop(Expression*& child) {
+ bool acted = false;
+ if (isConcreteWasmType(child->type)) {
+ expressionStack.push_back(child);
+ if (!ExpressionAnalyzer::isResultUsed(expressionStack, getFunction()) &&
+     !ExpressionAnalyzer::isResultDropped(expressionStack)) {
+ child = Builder(*getModule()).makeDrop(child);
+ acted = true;
+ }
+ expressionStack.pop_back();
+ }
+ return acted;
+ }
+
+ void reFinalize() {
+ ReFinalizeNode::updateStack(expressionStack);
+ }
+
+ void visitBlock(Block* curr) {
+ if (curr->list.size() == 0) return;
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ auto* child = curr->list[i];
+ if (isConcreteWasmType(child->type)) {
+ curr->list[i] = Builder(*getModule()).makeDrop(child);
+ }
+ }
+ if (maybeDrop(curr->list.back())) {
+ reFinalize();
+ assert(curr->type == none || curr->type == unreachable);
+ }
+ }
+
+ void visitIf(If* curr) {
+ bool acted = false;
+ if (maybeDrop(curr->ifTrue)) acted = true;
+ if (curr->ifFalse) {
+ if (maybeDrop(curr->ifFalse)) acted = true;
+ }
+ if (acted) {
+ reFinalize();
+ assert(curr->type == none);
+ }
+ }
+
+ void doWalkFunction(Function* curr) {
+ ReFinalize().walkFunctionInModule(curr, getModule());
+ walk(curr->body);
+ if (curr->result == none && isConcreteWasmType(curr->body->type)) {
+ curr->body = Builder(*getModule()).makeDrop(curr->body);
+ }
+ ReFinalize().walkFunctionInModule(curr, getModule());
+ }
+};
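+
+// For example, a block whose non-final child leaves a concrete value that is
+// never used, such as
+//   (block (i32.const 1) (nop))
+// becomes
+//   (block (drop (i32.const 1)) (nop))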
+
+struct I64Utilities {
+ // Recreates a full i64 value from its two i32 halves: low | ((i64)high << 32).
+ static Expression* recreateI64(Builder& builder, Expression* low, Expression* high) {
+   return builder.makeBinary(
+     OrInt64,
+     builder.makeUnary(ExtendUInt32, low),
+     builder.makeBinary(
+       ShlInt64,
+       builder.makeUnary(ExtendUInt32, high),
+       builder.makeConst(Literal(int64_t(32)))
+     )
+   );
+ }
+
+ // As above, but the halves are local indices holding the two i32 parts.
+ static Expression* recreateI64(Builder& builder, Index low, Index high) {
+   return recreateI64(builder, builder.makeGetLocal(low, i32), builder.makeGetLocal(high, i32));
+ }
+
+ // Extracts the high 32 bits of an i64 local: (i32)(x >> 32).
+ static Expression* getI64High(Builder& builder, Index index) {
+   return builder.makeUnary(
+     WrapInt64,
+     builder.makeBinary(
+       ShrUInt64,
+       builder.makeGetLocal(index, i64),
+       builder.makeConst(Literal(int64_t(32)))
+     )
+   );
+ }
+
+ // Extracts the low 32 bits of an i64 local: (i32)x.
+ static Expression* getI64Low(Builder& builder, Index index) {
+   return builder.makeUnary(WrapInt64, builder.makeGetLocal(index, i64));
+ }
+};
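+
+// Usage sketch (hypothetical local indices 'lowIndex'/'highIndex' holding the
+// two i32 halves, and a Module 'module'):
+//
+//   Builder builder(module);
+//   Expression* full = I64Utilities::recreateI64(builder, lowIndex, highIndex);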
+
+} // namespace wasm
+
+#endif // wasm_ir_utils_h