summaryrefslogtreecommitdiff
path: root/src/ast_utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast_utils.h')
-rw-r--r--src/ast_utils.h111
1 files changed, 95 insertions, 16 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 3e45d0e33..9b2ff10cd 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -21,6 +21,7 @@
#include "wasm.h"
#include "wasm-traversal.h"
#include "wasm-builder.h"
+#include "pass.h"
namespace wasm {
@@ -129,6 +130,9 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer
bool checkPost(Expression* curr) {
visit(curr);
+ if (curr->is<Loop>()) {
+ branches = true;
+ }
return hasAnything();
}
@@ -147,8 +151,7 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer
if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks
}
void visitLoop(Loop* curr) {
- if (curr->in.is()) breakNames.erase(curr->in); // these were internal breaks
- if (curr->out.is()) breakNames.erase(curr->out); // these were internal breaks
+ if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks
}
void visitCall(Call *curr) { calls = true; }
@@ -244,7 +247,7 @@ struct ExpressionManipulator {
return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse));
}
Expression* visitLoop(Loop *curr) {
- return builder.makeLoop(curr->out, curr->in, copy(curr->body));
+ return builder.makeLoop(curr->name, copy(curr->body));
}
Expression* visitBreak(Break *curr) {
return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition));
@@ -277,19 +280,23 @@ struct ExpressionManipulator {
return builder.makeGetLocal(curr->index, curr->type);
}
Expression* visitSetLocal(SetLocal *curr) {
- return builder.makeSetLocal(curr->index, copy(curr->value));
+ if (curr->isTee()) {
+ return builder.makeTeeLocal(curr->index, copy(curr->value));
+ } else {
+ return builder.makeSetLocal(curr->index, copy(curr->value));
+ }
}
Expression* visitGetGlobal(GetGlobal *curr) {
- return builder.makeGetGlobal(curr->index, curr->type);
+ return builder.makeGetGlobal(curr->name, curr->type);
}
Expression* visitSetGlobal(SetGlobal *curr) {
- return builder.makeSetGlobal(curr->index, copy(curr->value));
+ return builder.makeSetGlobal(curr->name, copy(curr->value));
}
Expression* visitLoad(Load *curr) {
return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type);
}
Expression* visitStore(Store *curr) {
- return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value));
+ return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType);
}
Expression* visitConst(Const *curr) {
return builder.makeConst(curr->value);
@@ -303,6 +310,9 @@ struct ExpressionManipulator {
Expression* visitSelect(Select *curr) {
return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse));
}
+ Expression* visitDrop(Drop *curr) {
+ return builder.makeDrop(copy(curr->value));
+ }
Expression* visitReturn(Return *curr) {
return builder.makeReturn(copy(curr->value));
}
@@ -340,7 +350,7 @@ struct ExpressionAnalyzer {
for (int i = int(stack.size()) - 2; i >= 0; i--) {
auto* curr = stack[i];
auto* above = stack[i + 1];
- // only if and block can drop values
+ // only if and block can drop values (pre-drop expression was added) FIXME
if (curr->is<Block>()) {
auto* block = curr->cast<Block>();
for (size_t j = 0; j < block->list.size() - 1; j++) {
@@ -355,6 +365,7 @@ struct ExpressionAnalyzer {
assert(above == iff->ifTrue || above == iff->ifFalse);
// continue down
} else {
+ if (curr->is<Drop>()) return false;
return true; // all other node types use the result
}
}
@@ -429,8 +440,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::LoopId: {
- if (!noteNames(left->cast<Loop>()->out, right->cast<Loop>()->out)) return false;
- if (!noteNames(left->cast<Loop>()->in, right->cast<Loop>()->in)) return false;
+ if (!noteNames(left->cast<Loop>()->name, right->cast<Loop>()->name)) return false;
PUSH(Loop, body);
break;
}
@@ -481,15 +491,16 @@ struct ExpressionAnalyzer {
}
case Expression::Id::SetLocalId: {
CHECK(SetLocal, index);
+ CHECK(SetLocal, type); // for tee/set
PUSH(SetLocal, value);
break;
}
case Expression::Id::GetGlobalId: {
- CHECK(GetGlobal, index);
+ CHECK(GetGlobal, name);
break;
}
case Expression::Id::SetGlobalId: {
- CHECK(SetGlobal, index);
+ CHECK(SetGlobal, name);
PUSH(SetGlobal, value);
break;
}
@@ -505,6 +516,7 @@ struct ExpressionAnalyzer {
CHECK(Store, bytes);
CHECK(Store, offset);
CHECK(Store, align);
+ CHECK(Store, valueType);
PUSH(Store, ptr);
PUSH(Store, value);
break;
@@ -530,6 +542,10 @@ struct ExpressionAnalyzer {
PUSH(Select, condition);
break;
}
+ case Expression::Id::DropId: {
+ PUSH(Drop, value);
+ break;
+ }
case Expression::Id::ReturnId: {
PUSH(Return, value);
break;
@@ -640,8 +656,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::LoopId: {
- noteName(curr->cast<Loop>()->out);
- noteName(curr->cast<Loop>()->in);
+ noteName(curr->cast<Loop>()->name);
PUSH(Loop, body);
break;
}
@@ -696,11 +711,11 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::GetGlobalId: {
- HASH(GetGlobal, index);
+ HASH_NAME(GetGlobal, name);
break;
}
case Expression::Id::SetGlobalId: {
- HASH(SetGlobal, index);
+ HASH_NAME(SetGlobal, name);
PUSH(SetGlobal, value);
break;
}
@@ -716,6 +731,7 @@ struct ExpressionAnalyzer {
HASH(Store, bytes);
HASH(Store, offset);
HASH(Store, align);
+ HASH(Store, valueType);
PUSH(Store, ptr);
PUSH(Store, value);
break;
@@ -742,6 +758,10 @@ struct ExpressionAnalyzer {
PUSH(Select, condition);
break;
}
+ case Expression::Id::DropId: {
+ PUSH(Drop, value);
+ break;
+ }
case Expression::Id::ReturnId: {
PUSH(Return, value);
break;
@@ -770,6 +790,65 @@ struct ExpressionAnalyzer {
}
};
+// Adds drop() operations where necessary. This lets you not worry about adding drop when
+// generating code.
+struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop, Visitor<AutoDrop>>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new AutoDrop; }
+
+ void visitBlock(Block* curr) {
+ if (curr->list.size() == 0) return;
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ auto* child = curr->list[i];
+ if (isConcreteWasmType(child->type)) {
+ curr->list[i] = Builder(*getModule()).makeDrop(child);
+ }
+ }
+ auto* last = curr->list.back();
+ expressionStack.push_back(last);
+ if (isConcreteWasmType(last->type) && !ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) {
+ curr->list.back() = Builder(*getModule()).makeDrop(last);
+ }
+ expressionStack.pop_back();
+ curr->finalize(); // we may have changed our type
+ }
+
+ void visitFunction(Function* curr) {
+ if (curr->result == none && isConcreteWasmType(curr->body->type)) {
+ curr->body = Builder(*getModule()).makeDrop(curr->body);
+ }
+ }
+};
+
+// Finalizes a node
+
+struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, Visitor<ReFinalize>>> {
+ void visitBlock(Block *curr) { curr->finalize(); }
+ void visitIf(If *curr) { curr->finalize(); }
+ void visitLoop(Loop *curr) { curr->finalize(); }
+ void visitBreak(Break *curr) { curr->finalize(); }
+ void visitSwitch(Switch *curr) { curr->finalize(); }
+ void visitCall(Call *curr) { curr->finalize(); }
+ void visitCallImport(CallImport *curr) { curr->finalize(); }
+ void visitCallIndirect(CallIndirect *curr) { curr->finalize(); }
+ void visitGetLocal(GetLocal *curr) { curr->finalize(); }
+ void visitSetLocal(SetLocal *curr) { curr->finalize(); }
+ void visitGetGlobal(GetGlobal *curr) { curr->finalize(); }
+ void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
+ void visitLoad(Load *curr) { curr->finalize(); }
+ void visitStore(Store *curr) { curr->finalize(); }
+ void visitConst(Const *curr) { curr->finalize(); }
+ void visitUnary(Unary *curr) { curr->finalize(); }
+ void visitBinary(Binary *curr) { curr->finalize(); }
+ void visitSelect(Select *curr) { curr->finalize(); }
+ void visitDrop(Drop *curr) { curr->finalize(); }
+ void visitReturn(Return *curr) { curr->finalize(); }
+ void visitHost(Host *curr) { curr->finalize(); }
+ void visitNop(Nop *curr) { curr->finalize(); }
+ void visitUnreachable(Unreachable *curr) { curr->finalize(); }
+};
+
} // namespace wasm
#endif // wasm_ast_utils_h