diff options
Diffstat (limited to 'src/ast_utils.h')
-rw-r--r-- | src/ast_utils.h | 111 |
1 files changed, 95 insertions, 16 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h index 3e45d0e33..9b2ff10cd 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -21,6 +21,7 @@ #include "wasm.h" #include "wasm-traversal.h" #include "wasm-builder.h" +#include "pass.h" namespace wasm { @@ -129,6 +130,9 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer bool checkPost(Expression* curr) { visit(curr); + if (curr->is<Loop>()) { + branches = true; + } return hasAnything(); } @@ -147,8 +151,7 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks } void visitLoop(Loop* curr) { - if (curr->in.is()) breakNames.erase(curr->in); // these were internal breaks - if (curr->out.is()) breakNames.erase(curr->out); // these were internal breaks + if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks } void visitCall(Call *curr) { calls = true; } @@ -244,7 +247,7 @@ struct ExpressionManipulator { return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); } Expression* visitLoop(Loop *curr) { - return builder.makeLoop(curr->out, curr->in, copy(curr->body)); + return builder.makeLoop(curr->name, copy(curr->body)); } Expression* visitBreak(Break *curr) { return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition)); @@ -277,19 +280,23 @@ struct ExpressionManipulator { return builder.makeGetLocal(curr->index, curr->type); } Expression* visitSetLocal(SetLocal *curr) { - return builder.makeSetLocal(curr->index, copy(curr->value)); + if (curr->isTee()) { + return builder.makeTeeLocal(curr->index, copy(curr->value)); + } else { + return builder.makeSetLocal(curr->index, copy(curr->value)); + } } Expression* visitGetGlobal(GetGlobal *curr) { - return builder.makeGetGlobal(curr->index, curr->type); + return builder.makeGetGlobal(curr->name, curr->type); } Expression* visitSetGlobal(SetGlobal *curr) { - return builder.makeSetGlobal(curr->index, copy(curr->value)); + return builder.makeSetGlobal(curr->name, copy(curr->value)); } Expression* visitLoad(Load *curr) { return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type); } Expression* visitStore(Store *curr) { - return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value)); + return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType); } Expression* visitConst(Const *curr) { return builder.makeConst(curr->value); @@ -303,6 +310,9 @@ struct ExpressionManipulator { Expression* visitSelect(Select *curr) { return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); } + Expression* visitDrop(Drop *curr) { + return builder.makeDrop(copy(curr->value)); + } Expression* visitReturn(Return *curr) { return builder.makeReturn(copy(curr->value)); } @@ -340,7 +350,7 @@ struct ExpressionAnalyzer { for (int i = int(stack.size()) - 2; i >= 0; i--) { auto* curr = stack[i]; auto* above = stack[i + 1]; - // only if and block can drop values + // only if and block can drop values (pre-drop expression was added) FIXME if (curr->is<Block>()) { auto* block = curr->cast<Block>(); for (size_t j = 0; j < block->list.size() - 1; j++) { @@ -355,6 +365,7 @@ struct ExpressionAnalyzer { assert(above == iff->ifTrue || above == iff->ifFalse); // continue down } else { + if (curr->is<Drop>()) return false; return true; // all other node types use the result } } @@ -429,8 +440,7 @@ struct ExpressionAnalyzer { break; } case Expression::Id::LoopId: { - if (!noteNames(left->cast<Loop>()->out, right->cast<Loop>()->out)) return false; - if (!noteNames(left->cast<Loop>()->in, right->cast<Loop>()->in)) return false; + if (!noteNames(left->cast<Loop>()->name, right->cast<Loop>()->name)) return false; PUSH(Loop, body); break; } @@ -481,15 +491,16 @@ struct ExpressionAnalyzer { } case Expression::Id::SetLocalId: { CHECK(SetLocal, index); + CHECK(SetLocal, type); // for tee/set PUSH(SetLocal, value); break; } case Expression::Id::GetGlobalId: { - CHECK(GetGlobal, index); + CHECK(GetGlobal, name); break; } case Expression::Id::SetGlobalId: { - CHECK(SetGlobal, index); + CHECK(SetGlobal, name); PUSH(SetGlobal, value); break; } @@ -505,6 +516,7 @@ struct ExpressionAnalyzer { CHECK(Store, bytes); CHECK(Store, offset); CHECK(Store, align); + CHECK(Store, valueType); PUSH(Store, ptr); PUSH(Store, value); break; @@ -530,6 +542,10 @@ struct ExpressionAnalyzer { PUSH(Select, condition); break; } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } case Expression::Id::ReturnId: { PUSH(Return, value); break; @@ -640,8 +656,7 @@ struct ExpressionAnalyzer { break; } case Expression::Id::LoopId: { - noteName(curr->cast<Loop>()->out); - noteName(curr->cast<Loop>()->in); + noteName(curr->cast<Loop>()->name); PUSH(Loop, body); break; } @@ -696,11 +711,11 @@ struct ExpressionAnalyzer { break; } case Expression::Id::GetGlobalId: { - HASH(GetGlobal, index); + HASH_NAME(GetGlobal, name); break; } case Expression::Id::SetGlobalId: { - HASH(SetGlobal, index); + HASH_NAME(SetGlobal, name); PUSH(SetGlobal, value); break; } @@ -716,6 +731,7 @@ struct ExpressionAnalyzer { HASH(Store, bytes); HASH(Store, offset); HASH(Store, align); + HASH(Store, valueType); PUSH(Store, ptr); PUSH(Store, value); break; @@ -742,6 +758,10 @@ struct ExpressionAnalyzer { PUSH(Select, condition); break; } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } case Expression::Id::ReturnId: { PUSH(Return, value); break; @@ -770,6 +790,65 @@ struct ExpressionAnalyzer { } }; +// Adds drop() operations where necessary. This lets you not worry about adding drop when +// generating code. +struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop, Visitor<AutoDrop>>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new AutoDrop; } + + void visitBlock(Block* curr) { + if (curr->list.size() == 0) return; + for (Index i = 0; i < curr->list.size() - 1; i++) { + auto* child = curr->list[i]; + if (isConcreteWasmType(child->type)) { + curr->list[i] = Builder(*getModule()).makeDrop(child); + } + } + auto* last = curr->list.back(); + expressionStack.push_back(last); + if (isConcreteWasmType(last->type) && !ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) { + curr->list.back() = Builder(*getModule()).makeDrop(last); + } + expressionStack.pop_back(); + curr->finalize(); // we may have changed our type + } + + void visitFunction(Function* curr) { + if (curr->result == none && isConcreteWasmType(curr->body->type)) { + curr->body = Builder(*getModule()).makeDrop(curr->body); + } + } +}; + +// Finalizes a node + +struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, Visitor<ReFinalize>>> { + void visitBlock(Block *curr) { curr->finalize(); } + void visitIf(If *curr) { curr->finalize(); } + void visitLoop(Loop *curr) { curr->finalize(); } + void visitBreak(Break *curr) { curr->finalize(); } + void visitSwitch(Switch *curr) { curr->finalize(); } + void visitCall(Call *curr) { curr->finalize(); } + void visitCallImport(CallImport *curr) { curr->finalize(); } + void visitCallIndirect(CallIndirect *curr) { curr->finalize(); } + void visitGetLocal(GetLocal *curr) { curr->finalize(); } + void visitSetLocal(SetLocal *curr) { curr->finalize(); } + void visitGetGlobal(GetGlobal *curr) { curr->finalize(); } + void visitSetGlobal(SetGlobal *curr) { curr->finalize(); } + void visitLoad(Load *curr) { curr->finalize(); } + void visitStore(Store *curr) { curr->finalize(); } + void visitConst(Const *curr) { curr->finalize(); } + void visitUnary(Unary *curr) { curr->finalize(); } + void visitBinary(Binary *curr) { curr->finalize(); } + void visitSelect(Select *curr) { curr->finalize(); } + void visitDrop(Drop *curr) { curr->finalize(); } + void visitReturn(Return *curr) { curr->finalize(); } + void visitHost(Host *curr) { curr->finalize(); } + void visitNop(Nop *curr) { curr->finalize(); } + void visitUnreachable(Unreachable *curr) { curr->finalize(); } +}; + } // namespace wasm #endif // wasm_ast_utils_h |