summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/asm2wasm.h28
-rw-r--r--src/ast_utils.h50
-rw-r--r--src/binaryen-c.cpp43
-rw-r--r--src/binaryen-c.h9
-rw-r--r--src/passes/CoalesceLocals.cpp12
-rw-r--r--src/passes/DeadCodeElimination.cpp21
-rw-r--r--src/passes/MergeBlocks.cpp19
-rw-r--r--src/passes/Print.cpp16
-rw-r--r--src/passes/SimplifyLocals.cpp42
-rw-r--r--src/passes/Vacuum.cpp8
-rw-r--r--src/s2wasm.h5
-rw-r--r--src/shell-interface.h2
-rw-r--r--src/wasm-binary.h45
-rw-r--r--src/wasm-builder.h24
-rw-r--r--src/wasm-interpreter.h12
-rw-r--r--src/wasm-js.cpp4
-rw-r--r--src/wasm-s-parser.h27
-rw-r--r--src/wasm-traversal.h13
-rw-r--r--src/wasm-validator.h124
-rw-r--r--src/wasm.h28
20 files changed, 421 insertions, 111 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h
index ce3c0749d..857145937 100644
--- a/src/asm2wasm.h
+++ b/src/asm2wasm.h
@@ -793,6 +793,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
};
PassRunner passRunner(&wasm);
passRunner.add<FinalizeCalls>(this);
+ passRunner.add<AutoDrop>();
passRunner.run();
// apply memory growth, if relevant
@@ -802,7 +803,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
wasm.addFunction(builder.makeFunction(
GROW_WASM_MEMORY,
{ { NEW_SIZE, i32 } },
- none,
+ i32,
{},
builder.makeHost(
GrowMemory,
@@ -1009,7 +1010,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
auto ret = allocator.alloc<SetLocal>();
ret->index = function->getLocalIndex(ast[2][1]->getIString());
ret->value = process(ast[3]);
- ret->type = ret->value->type;
+ ret->setTee(false);
+ ret->finalize();
return ret;
}
// global var, do a store to memory
@@ -1021,7 +1023,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
ret->align = ret->bytes;
ret->ptr = builder.makeConst(Literal(int32_t(global.address)));
ret->value = process(ast[3]);
- ret->type = global.type;
+ ret->valueType = global.type;
+ ret->finalize();
return ret;
} else if (ast[2][0] == SUB) {
Ref target = ast[2];
@@ -1035,10 +1038,11 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
ret->align = view.bytes;
ret->ptr = processUnshifted(target[2], view.bytes);
ret->value = process(ast[3]);
- ret->type = asmToWasmType(view.type);
- if (ret->type != ret->value->type) {
+ ret->valueType = asmToWasmType(view.type);
+ ret->finalize();
+ if (ret->valueType != ret->value->type) {
// in asm.js we have some implicit coercions that we must do explicitly here
- if (ret->type == f32 && ret->value->type == f64) {
+ if (ret->valueType == f32 && ret->value->type == f64) {
auto conv = allocator.alloc<Unary>();
conv->op = DemoteFloat64;
conv->value = ret->value;
@@ -1273,11 +1277,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
}
abort_on("bad unary", ast);
} else if (what == IF) {
- auto ret = allocator.alloc<If>();
- ret->condition = process(ast[1]);
- ret->ifTrue = process(ast[2]);
- ret->ifFalse = !!ast[3] ? process(ast[3]) : nullptr;
- return ret;
+ return builder.makeIf(process(ast[1]), process(ast[2]), !!ast[3] ? process(ast[3]) : nullptr);
} else if (what == CALL) {
if (ast[1][0] == NAME) {
IString name = ast[1][1]->getIString();
@@ -1330,9 +1330,10 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
// No wasm support, so use a temp local
ensureI32Temp();
auto set = allocator.alloc<SetLocal>();
+ set->setTee(false);
set->index = function->getLocalIndex(I32_TEMP);
set->value = value;
- set->type = i32;
+ set->finalize();
auto get = [&]() {
auto ret = allocator.alloc<GetLocal>();
ret->index = function->getLocalIndex(I32_TEMP);
@@ -1526,6 +1527,9 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
if (breakSeeker.found == 0) {
auto block = allocator.alloc<Block>();
block->list.push_back(child);
+ if (isConcreteWasmType(child->type)) {
+ block->list.push_back(builder.makeNop()); // ensure a nop at the end, so the block has guaranteed none type and no values fall through
+ }
block->name = stop;
block->finalize();
return block;
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 3e45d0e33..17f5490a9 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -21,6 +21,7 @@
#include "wasm.h"
#include "wasm-traversal.h"
#include "wasm-builder.h"
+#include "pass.h"
namespace wasm {
@@ -277,7 +278,11 @@ struct ExpressionManipulator {
return builder.makeGetLocal(curr->index, curr->type);
}
Expression* visitSetLocal(SetLocal *curr) {
- return builder.makeSetLocal(curr->index, copy(curr->value));
+ if (curr->isTee()) {
+ return builder.makeTeeLocal(curr->index, copy(curr->value));
+ } else {
+ return builder.makeSetLocal(curr->index, copy(curr->value));
+ }
}
Expression* visitGetGlobal(GetGlobal *curr) {
return builder.makeGetGlobal(curr->index, curr->type);
@@ -289,7 +294,7 @@ struct ExpressionManipulator {
return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type);
}
Expression* visitStore(Store *curr) {
- return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value));
+ return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType);
}
Expression* visitConst(Const *curr) {
return builder.makeConst(curr->value);
@@ -303,6 +308,9 @@ struct ExpressionManipulator {
Expression* visitSelect(Select *curr) {
return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse));
}
+ Expression* visitDrop(Drop *curr) {
+ return builder.makeDrop(copy(curr->value));
+ }
Expression* visitReturn(Return *curr) {
return builder.makeReturn(copy(curr->value));
}
@@ -340,7 +348,7 @@ struct ExpressionAnalyzer {
for (int i = int(stack.size()) - 2; i >= 0; i--) {
auto* curr = stack[i];
auto* above = stack[i + 1];
- // only if and block can drop values
+ // only if and block can drop values (pre-drop expression was added) FIXME
if (curr->is<Block>()) {
auto* block = curr->cast<Block>();
for (size_t j = 0; j < block->list.size() - 1; j++) {
@@ -355,6 +363,7 @@ struct ExpressionAnalyzer {
assert(above == iff->ifTrue || above == iff->ifFalse);
// continue down
} else {
+ if (curr->is<Drop>()) return false;
return true; // all other node types use the result
}
}
@@ -481,6 +490,7 @@ struct ExpressionAnalyzer {
}
case Expression::Id::SetLocalId: {
CHECK(SetLocal, index);
+ CHECK(SetLocal, type); // for tee/set
PUSH(SetLocal, value);
break;
}
@@ -505,6 +515,7 @@ struct ExpressionAnalyzer {
CHECK(Store, bytes);
CHECK(Store, offset);
CHECK(Store, align);
+ CHECK(Store, valueType);
PUSH(Store, ptr);
PUSH(Store, value);
break;
@@ -530,6 +541,10 @@ struct ExpressionAnalyzer {
PUSH(Select, condition);
break;
}
+ case Expression::Id::DropId: {
+ PUSH(Drop, value);
+ break;
+ }
case Expression::Id::ReturnId: {
PUSH(Return, value);
break;
@@ -716,6 +731,7 @@ struct ExpressionAnalyzer {
HASH(Store, bytes);
HASH(Store, offset);
HASH(Store, align);
+ HASH(Store, valueType);
PUSH(Store, ptr);
PUSH(Store, value);
break;
@@ -742,6 +758,10 @@ struct ExpressionAnalyzer {
PUSH(Select, condition);
break;
}
+ case Expression::Id::DropId: {
+ PUSH(Drop, value);
+ break;
+ }
case Expression::Id::ReturnId: {
PUSH(Return, value);
break;
@@ -770,6 +790,30 @@ struct ExpressionAnalyzer {
}
};
+// Adds drop() operations where necessary. This lets you not worry about adding drop when
+// generating code.
+struct AutoDrop : public WalkerPass<PostWalker<AutoDrop, Visitor<AutoDrop>>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new AutoDrop; }
+
+ void visitBlock(Block* curr) {
+ if (curr->list.size() <= 1) return;
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ auto* child = curr->list[i];
+ if (isConcreteWasmType(child->type)) {
+ curr->list[i] = Builder(*getModule()).makeDrop(child);
+ }
+ }
+ }
+
+ void visitFunction(Function* curr) {
+ if (curr->result == none && isConcreteWasmType(curr->body->type)) {
+ curr->body = Builder(*getModule()).makeDrop(curr->body);
+ }
+ }
+};
+
} // namespace wasm
#endif // wasm_ast_utils_h
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp
index 83c626eb1..71c6cad80 100644
--- a/src/binaryen-c.cpp
+++ b/src/binaryen-c.cpp
@@ -488,6 +488,21 @@ BinaryenExpressionRef BinaryenSetLocal(BinaryenModuleRef module, BinaryenIndex i
ret->index = index;
ret->value = (Expression*)value;
+ ret->setTee(false);
+ ret->finalize();
+ return static_cast<Expression*>(ret);
+}
+BinaryenExpressionRef BinaryenTeeLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value) {
+ auto* ret = ((Module*)module)->allocator.alloc<SetLocal>();
+
+ if (tracing) {
+ auto id = noteExpression(ret);
+ std::cout << " expressions[" << id << "] = BinaryenTeeLocal(the_module, " << index << ", expressions[" << expressions[value] << "]);\n";
+ }
+
+ ret->index = index;
+ ret->value = (Expression*)value;
+ ret->setTee(true);
ret->finalize();
return static_cast<Expression*>(ret);
}
@@ -508,12 +523,12 @@ BinaryenExpressionRef BinaryenLoad(BinaryenModuleRef module, uint32_t bytes, int
ret->finalize();
return static_cast<Expression*>(ret);
}
-BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value) {
+BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value, BinaryenType type) {
auto* ret = ((Module*)module)->allocator.alloc<Store>();
if (tracing) {
auto id = noteExpression(ret);
- std::cout << " expressions[" << id << "] = BinaryenStore(the_module, " << bytes << ", " << offset << ", " << align << ", expressions[" << expressions[ptr] << "], expressions[" << expressions[value] << "]);\n";
+ std::cout << " expressions[" << id << "] = BinaryenStore(the_module, " << bytes << ", " << offset << ", " << align << ", expressions[" << expressions[ptr] << "], expressions[" << expressions[value] << "], " << type << ");\n";
}
ret->bytes = bytes;
@@ -521,6 +536,7 @@ BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, ui
ret->align = align ? align : bytes;
ret->ptr = (Expression*)ptr;
ret->value = (Expression*)value;
+ ret->valueType = WasmType(type);
ret->finalize();
return static_cast<Expression*>(ret);
}
@@ -584,6 +600,18 @@ BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressio
ret->finalize();
return static_cast<Expression*>(ret);
}
+BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value) {
+ auto* ret = ((Module*)module)->allocator.alloc<Drop>();
+
+ if (tracing) {
+ auto id = noteExpression(ret);
+ std::cout << " expressions[" << id << "] = BinaryenDrop(the_module, expressions[" << expressions[value] << "]);\n";
+ }
+
+ ret->value = (Expression*)value;
+ ret->finalize();
+ return static_cast<Expression*>(ret);
+}
BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value) {
auto* ret = Builder(*((Module*)module)).makeReturn((Expression*)value);
@@ -829,6 +857,17 @@ void BinaryenModuleOptimize(BinaryenModuleRef module) {
passRunner.run();
}
+void BinaryenModuleAutoDrop(BinaryenModuleRef module) {
+ if (tracing) {
+ std::cout << " BinaryenModuleAutoDrop(the_module);\n";
+ }
+
+ Module* wasm = (Module*)module;
+ PassRunner passRunner(wasm);
+ passRunner.add<AutoDrop>();
+ passRunner.run();
+}
+
size_t BinaryenModuleWrite(BinaryenModuleRef module, char* output, size_t outputSize) {
if (tracing) {
std::cout << " // BinaryenModuleWrite\n";
diff --git a/src/binaryen-c.h b/src/binaryen-c.h
index e2086072b..7ca5cd6c7 100644
--- a/src/binaryen-c.h
+++ b/src/binaryen-c.h
@@ -294,14 +294,16 @@ BinaryenExpressionRef BinaryenCallIndirect(BinaryenModuleRef module, BinaryenExp
// for more details.
BinaryenExpressionRef BinaryenGetLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenType type);
BinaryenExpressionRef BinaryenSetLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value);
+BinaryenExpressionRef BinaryenTeeLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value);
// Load: align can be 0, in which case it will be the natural alignment (equal to bytes)
BinaryenExpressionRef BinaryenLoad(BinaryenModuleRef module, uint32_t bytes, int8_t signed_, uint32_t offset, uint32_t align, BinaryenType type, BinaryenExpressionRef ptr);
// Store: align can be 0, in which case it will be the natural alignment (equal to bytes)
-BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value);
+BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value, BinaryenType type);
BinaryenExpressionRef BinaryenConst(BinaryenModuleRef module, struct BinaryenLiteral value);
BinaryenExpressionRef BinaryenUnary(BinaryenModuleRef module, BinaryenOp op, BinaryenExpressionRef value);
BinaryenExpressionRef BinaryenBinary(BinaryenModuleRef module, BinaryenOp op, BinaryenExpressionRef left, BinaryenExpressionRef right);
BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, BinaryenExpressionRef ifFalse);
+BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value);
// Return: value can be NULL
BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value);
// Host: name may be NULL
@@ -366,6 +368,11 @@ int BinaryenModuleValidate(BinaryenModuleRef module);
// Run the standard optimization passes on the module.
void BinaryenModuleOptimize(BinaryenModuleRef module);
+// Auto-generate drop() operations where needed. This lets you generate code without
+// worrying about where they are needed. (It is more efficient to do it yourself,
+// but simpler to use autodrop).
+void BinaryenModuleAutoDrop(BinaryenModuleRef module);
+
// Serialize a module into binary form.
// @return how many bytes were written. This will be less than or equal to bufferSize
size_t BinaryenModuleWrite(BinaryenModuleRef module, char* output, size_t outputSize);
diff --git a/src/passes/CoalesceLocals.cpp b/src/passes/CoalesceLocals.cpp
index 2c707b614..2063a20db 100644
--- a/src/passes/CoalesceLocals.cpp
+++ b/src/passes/CoalesceLocals.cpp
@@ -485,9 +485,19 @@ void CoalesceLocals::applyIndices(std::vector<Index>& indices, Expression* root)
// in addition, we can optimize out redundant copies and ineffective sets
GetLocal* get;
if ((get = set->value->dynCast<GetLocal>()) && get->index == set->index) {
- *action.origin = get; // further optimizations may get rid of the get, if this is in a place where the output does not matter
+ if (set->isTee()) {
+ *action.origin = get;
+ } else {
+ ExpressionManipulator::nop(set);
+ }
} else if (!action.effective) {
*action.origin = set->value; // value may have no side effects, further optimizations can eliminate it
+ if (!set->isTee()) {
+ // we need to drop it
+ Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(set);
+ drop->value = *action.origin;
+ *action.origin = drop;
+ }
}
}
}
diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp
index de1191131..aba442e71 100644
--- a/src/passes/DeadCodeElimination.cpp
+++ b/src/passes/DeadCodeElimination.cpp
@@ -31,6 +31,7 @@
#include <wasm.h>
#include <pass.h>
#include <ast_utils.h>
+#include <wasm-builder.h>
namespace wasm {
@@ -191,6 +192,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
case Expression::Id::UnaryId: DELEGATE(Unary);
case Expression::Id::BinaryId: DELEGATE(Binary);
case Expression::Id::SelectId: DELEGATE(Select);
+ case Expression::Id::DropId: DELEGATE(Drop);
case Expression::Id::ReturnId: DELEGATE(Return);
case Expression::Id::HostId: DELEGATE(Host);
case Expression::Id::NopId: DELEGATE(Nop);
@@ -226,6 +228,11 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
// other things
+ Expression* drop(Expression* toDrop) {
+ if (toDrop->is<Unreachable>()) return toDrop;
+ return Builder(*getModule()).makeDrop(toDrop);
+ }
+
template<typename T>
void handleCall(T* curr, Expression* initial) {
for (Index i = 0; i < curr->operands.size(); i++) {
@@ -236,11 +243,11 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
block->list.resize(newSize);
Index j = 0;
if (initial) {
- block->list[j] = initial;
+ block->list[j] = drop(initial);
j++;
}
for (; j < newSize; j++) {
- block->list[j] = curr->operands[j - (initial ? 1 : 0)];
+ block->list[j] = drop(curr->operands[j - (initial ? 1 : 0)]);
}
block->finalize();
replaceCurrent(block);
@@ -288,7 +295,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
if (isDead(curr->value)) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(2);
- block->list[0] = curr->ptr;
+ block->list[0] = drop(curr->ptr);
block->list[1] = curr->value;
block->finalize();
replaceCurrent(block);
@@ -309,7 +316,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
if (isDead(curr->right)) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(2);
- block->list[0] = curr->left;
+ block->list[0] = drop(curr->left);
block->list[1] = curr->right;
block->finalize();
replaceCurrent(block);
@@ -324,7 +331,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
if (isDead(curr->ifFalse)) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(2);
- block->list[0] = curr->ifTrue;
+ block->list[0] = drop(curr->ifTrue);
block->list[1] = curr->ifFalse;
block->finalize();
replaceCurrent(block);
@@ -333,8 +340,8 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
if (isDead(curr->condition)) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(3);
- block->list[0] = curr->ifTrue;
- block->list[1] = curr->ifFalse;
+ block->list[0] = drop(curr->ifTrue);
+ block->list[1] = drop(curr->ifFalse);
block->list[2] = curr->condition;
block->finalize();
replaceCurrent(block);
diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp
index 686bb5d75..4798ad28d 100644
--- a/src/passes/MergeBlocks.cpp
+++ b/src/passes/MergeBlocks.cpp
@@ -74,10 +74,27 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc
void visitBlock(Block *curr) {
bool more = true;
+ bool changed = false;
while (more) {
more = false;
for (size_t i = 0; i < curr->list.size(); i++) {
Block* child = curr->list[i]->dynCast<Block>();
+ if (!child) {
+ // TODO: if we have a child that is (drop (block ..)) then we can move the drop into the block, allowing more merging,
+ // but we must also drop values from brs
+ /*
+ auto* drop = curr->list[i]->dynCast<Drop>();
+ if (drop) {
+ child = drop->value->dynCast<Block>();
+ if (child) {
+ // reuse the drop
+ drop->value = child->list.back();
+ child->list.back() = drop;
+ curr->list[i] = child;
+ }
+ }
+ */
+ }
if (!child) continue;
if (child->name.is()) continue; // named blocks can have breaks to them (and certainly do, if we ran RemoveUnusedNames and RemoveUnusedBrs)
ExpressionList merged(getModule()->allocator);
@@ -92,9 +109,11 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc
}
curr->list = merged;
more = true;
+ changed = true;
break;
}
}
+ if (changed) curr->finalize();
}
Block* optimize(Expression* curr, Expression*& child, Block* outer = nullptr, Expression** dependency1 = nullptr, Expression** dependency2 = nullptr) {
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 7486e07f1..ac349abc7 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -240,7 +240,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printOpening(o, "get_local ") << printableLocal(curr->index) << ')';
}
void visitSetLocal(SetLocal *curr) {
- printOpening(o, "set_local ") << printableLocal(curr->index);
+ if (curr->isTee()) {
+ printOpening(o, "tee_local ");
+ } else {
+ printOpening(o, "set_local ");
+ }
+ o << printableLocal(curr->index);
incIndent();
printFullLine(curr->value);
decIndent();
@@ -282,7 +287,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
}
void visitStore(Store *curr) {
o << '(';
- prepareColor(o) << printWasmType(curr->type) << ".store";
+ prepareColor(o) << printWasmType(curr->valueType) << ".store";
if (curr->bytes < 4 || (curr->type == i64 && curr->bytes < 8)) {
if (curr->bytes == 1) {
o << '8';
@@ -467,6 +472,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printFullLine(curr->condition);
decIndent();
}
+ void visitDrop(Drop *curr) {
+ o << '(';
+ prepareColor(o) << "drop";
+ incIndent();
+ printFullLine(curr->value);
+ decIndent();
+ }
void visitReturn(Return *curr) {
printOpening(o, "return");
if (!curr->value || curr->value->is<Nop>()) {
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index d9786e62c..5315edac4 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -55,7 +55,13 @@ struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemo
void visitSetLocal(SetLocal *curr) {
if ((*numGetLocals)[curr->index] == 0) {
- replaceCurrent(curr->value);
+ auto* value = curr->value;
+ if (curr->isTee()) {
+ replaceCurrent(value);
+ } else {
+ Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(curr);
+ drop->value = value;
+ }
}
}
};
@@ -180,7 +186,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
auto found = sinkables.find(curr->index);
if (found != sinkables.end()) {
// sink it, and nop the origin
- replaceCurrent(*found->second.item);
+ auto* set = (*found->second.item)->cast<SetLocal>();
+ replaceCurrent(set);
+ assert(!set->isTee());
+ set->setTee(true);
// reuse the getlocal that is dying
*found->second.item = curr;
ExpressionManipulator::nop(curr);
@@ -189,6 +198,16 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
}
}
+ void visitDrop(Drop* curr) {
+ // collapse drop-tee into set, which can occur if a get was sunk into a tee
+ auto* set = curr->value->dynCast<SetLocal>();
+ if (set) {
+ assert(set->isTee());
+ set->setTee(false);
+ replaceCurrent(set);
+ }
+ }
+
void checkInvalidations(EffectAnalyzer& effects) {
// TODO: this is O(bad)
std::vector<Index> invalidated;
@@ -225,7 +244,11 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
// store is dead, leave just the value
auto found = self->sinkables.find(set->index);
if (found != self->sinkables.end()) {
- *found->second.item = (*found->second.item)->cast<SetLocal>()->value;
+ auto* previous = (*found->second.item)->cast<SetLocal>();
+ assert(!previous->isTee());
+ auto* previousValue = previous->value;
+ Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(previous);
+ drop->value = previousValue;
self->sinkables.erase(found);
self->anotherCycle = true;
}
@@ -236,15 +259,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
self->checkInvalidations(effects);
}
- if (set) {
- // we may be a replacement for the current node, update the stack
- self->expressionStack.pop_back();
- self->expressionStack.push_back(set);
- if (!ExpressionAnalyzer::isResultUsed(self->expressionStack, self->getFunction())) {
- Index index = set->index;
- assert(self->sinkables.count(index) == 0);
- self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp)));
- }
+ if (set && !set->isTee()) {
+ Index index = set->index;
+ assert(self->sinkables.count(index) == 0);
+ self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp)));
}
self->expressionStack.pop_back();
diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp
index 0195427d7..1b0887181 100644
--- a/src/passes/Vacuum.cpp
+++ b/src/passes/Vacuum.cpp
@@ -41,6 +41,7 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> {
case Expression::Id::BlockId: return curr; // not always needed, but handled in visitBlock()
case Expression::Id::IfId: return curr; // not always needed, but handled in visitIf()
case Expression::Id::LoopId: return curr; // not always needed, but handled in visitLoop()
+ case Expression::Id::DropId: return curr; // not always needed, but handled in visitDrop()
case Expression::Id::BreakId:
case Expression::Id::SwitchId:
@@ -198,6 +199,13 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> {
if (curr->body->is<Nop>()) ExpressionManipulator::nop(curr);
}
+ void visitDrop(Drop* curr) {
+ // if the drop input has no side effects, it can be wiped out
+ if (!EffectAnalyzer(curr->value).hasSideEffects()) {
+ ExpressionManipulator::nop(curr);
+ }
+ }
+
static void visitPre(Vacuum* self, Expression** currp) {
self->expressionStack.push_back(*currp);
}
diff --git a/src/s2wasm.h b/src/s2wasm.h
index a23cbb549..5226ceda3 100644
--- a/src/s2wasm.h
+++ b/src/s2wasm.h
@@ -753,6 +753,7 @@ class S2WasmBuilder {
set->index = func->getLocalIndex(assign);
set->value = curr;
set->type = curr->type;
+ set->setTee(false);
addToBlock(set);
}
};
@@ -834,7 +835,7 @@ class S2WasmBuilder {
auto makeStore = [&](WasmType type) {
skipComma();
auto curr = allocator->alloc<Store>();
- curr->type = type;
+ curr->valueType = type;
int32_t bytes = getInt() / CHAR_BIT;
curr->bytes = bytes > 0 ? bytes : getWasmTypeSize(type);
Name assign = getAssign();
@@ -849,6 +850,7 @@ class S2WasmBuilder {
curr->align = 1U << getInt(attributes[0] + 8);
}
curr->value = inputs[1];
+ curr->finalize();
setOutput(curr, assign);
};
auto makeSelect = [&](WasmType type) {
@@ -1146,6 +1148,7 @@ class S2WasmBuilder {
Name assign = getAssign();
skipComma();
auto curr = allocator->alloc<SetLocal>();
+ curr->setTee(true);
curr->index = func->getLocalIndex(getAssign());
skipComma();
curr->value = getInput();
diff --git a/src/shell-interface.h b/src/shell-interface.h
index f332307ad..2b0c491b9 100644
--- a/src/shell-interface.h
+++ b/src/shell-interface.h
@@ -167,7 +167,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
}
void store(Store* store, Address addr, Literal value) override {
- switch (store->type) {
+ switch (store->valueType) {
case i32: {
switch (store->bytes) {
case 1: memory.set<int8_t>(addr, value.geti32()); break;
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 04bebddc5..791a0459b 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -412,6 +412,7 @@ enum ASTNodes {
CallFunction = 0x16,
CallIndirect = 0x17,
CallImport = 0x18,
+ TeeLocal = 0x19,
GetGlobal = 0x1a,
SetGlobal = 0x1b,
@@ -426,6 +427,7 @@ enum ASTNodes {
TableSwitch = 0x08,
Return = 0x09,
Unreachable = 0x0a,
+ Drop = 0x0b,
End = 0x0f
};
@@ -939,9 +941,9 @@ public:
o << int8_t(BinaryConsts::GetLocal) << U32LEB(mappedLocals[curr->index]);
}
void visitSetLocal(SetLocal *curr) {
- if (debug) std::cerr << "zz node: SetLocal" << std::endl;
+ if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl;
recurse(curr->value);
- o << int8_t(BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]);
+ o << int8_t(curr->isTee() ? BinaryConsts::TeeLocal : BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]);
}
void visitGetGlobal(GetGlobal *curr) {
if (debug) std::cerr << "zz node: GetGlobal " << (o.size() + 1) << std::endl;
@@ -991,7 +993,7 @@ public:
if (debug) std::cerr << "zz node: Store" << std::endl;
recurse(curr->ptr);
recurse(curr->value);
- switch (curr->type) {
+ switch (curr->valueType) {
case i32: {
switch (curr->bytes) {
case 1: o << int8_t(BinaryConsts::I32StoreMem8); break;
@@ -1219,6 +1221,11 @@ public:
if (debug) std::cerr << "zz node: Unreachable" << std::endl;
o << int8_t(BinaryConsts::Unreachable);
}
+ void visitDrop(Drop *curr) {
+ if (debug) std::cerr << "zz node: Drop" << std::endl;
+ recurse(curr->value);
+ o << int8_t(BinaryConsts::Drop);
+ }
};
class WasmBinaryBuilder {
@@ -1728,13 +1735,15 @@ public:
case BinaryConsts::CallImport: visitCallImport((curr = allocator.alloc<CallImport>())->cast<CallImport>()); break;
case BinaryConsts::CallIndirect: visitCallIndirect((curr = allocator.alloc<CallIndirect>())->cast<CallIndirect>()); break;
case BinaryConsts::GetLocal: visitGetLocal((curr = allocator.alloc<GetLocal>())->cast<GetLocal>()); break;
- case BinaryConsts::SetLocal: visitSetLocal((curr = allocator.alloc<SetLocal>())->cast<SetLocal>()); break;
+ case BinaryConsts::TeeLocal:
+ case BinaryConsts::SetLocal: visitSetLocal((curr = allocator.alloc<SetLocal>())->cast<SetLocal>(), code); break;
case BinaryConsts::GetGlobal: visitGetGlobal((curr = allocator.alloc<GetGlobal>())->cast<GetGlobal>()); break;
case BinaryConsts::SetGlobal: visitSetGlobal((curr = allocator.alloc<SetGlobal>())->cast<SetGlobal>()); break;
case BinaryConsts::Select: visitSelect((curr = allocator.alloc<Select>())->cast<Select>()); break;
case BinaryConsts::Return: visitReturn((curr = allocator.alloc<Return>())->cast<Return>()); break;
case BinaryConsts::Nop: visitNop((curr = allocator.alloc<Nop>())->cast<Nop>()); break;
case BinaryConsts::Unreachable: visitUnreachable((curr = allocator.alloc<Unreachable>())->cast<Unreachable>()); break;
+ case BinaryConsts::Drop: visitDrop((curr = allocator.alloc<Drop>())->cast<Drop>()); break;
case BinaryConsts::End:
case BinaryConsts::Else: curr = nullptr; break;
default: {
@@ -1922,12 +1931,13 @@ public:
assert(curr->index < currFunction->getNumLocals());
curr->type = currFunction->getLocalType(curr->index);
}
- void visitSetLocal(SetLocal *curr) {
- if (debug) std::cerr << "zz node: SetLocal" << std::endl;
+ void visitSetLocal(SetLocal *curr, uint8_t code) {
+ if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl;
curr->index = getU32LEB();
assert(curr->index < currFunction->getNumLocals());
curr->value = popExpression();
curr->type = curr->value->type;
+ curr->setTee(code == BinaryConsts::TeeLocal);
}
void visitGetGlobal(GetGlobal *curr) {
if (debug) std::cerr << "zz node: GetGlobal " << pos << std::endl;
@@ -1976,21 +1986,22 @@ public:
bool maybeVisitStore(Expression*& out, uint8_t code) {
Store* curr;
switch (code) {
- case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->type = i32; break;
- case BinaryConsts::I32StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->type = i32; break;
- case BinaryConsts::I32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = i32; break;
- case BinaryConsts::I64StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->type = i64; break;
- case BinaryConsts::I64StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->type = i64; break;
- case BinaryConsts::I64StoreMem32: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = i64; break;
- case BinaryConsts::I64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->type = i64; break;
- case BinaryConsts::F32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = f32; break;
- case BinaryConsts::F64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->type = f64; break;
+ case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->valueType = i32; break;
+ case BinaryConsts::I32StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->valueType = i32; break;
+ case BinaryConsts::I32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = i32; break;
+ case BinaryConsts::I64StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->valueType = i64; break;
+ case BinaryConsts::I64StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->valueType = i64; break;
+ case BinaryConsts::I64StoreMem32: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = i64; break;
+ case BinaryConsts::I64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->valueType = i64; break;
+ case BinaryConsts::F32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = f32; break;
+ case BinaryConsts::F64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->valueType = f64; break;
default: return false;
}
if (debug) std::cerr << "zz node: Store" << std::endl;
readMemoryAccess(curr->align, curr->bytes, curr->offset);
curr->value = popExpression();
curr->ptr = popExpression();
+ curr->finalize();
out = curr;
return true;
}
@@ -2175,6 +2186,10 @@ public:
void visitUnreachable(Unreachable *curr) {
if (debug) std::cerr << "zz node: Unreachable" << std::endl;
}
+ void visitDrop(Drop *curr) {
+ if (debug) std::cerr << "zz node: Drop" << std::endl;
+ curr->value = popExpression();
+ }
};
} // namespace wasm
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index 22e1f9a00..3cba351a2 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -142,6 +142,13 @@ public:
auto* ret = allocator.alloc<SetLocal>();
ret->index = index;
ret->value = value;
+ ret->type = none;
+ return ret;
+ }
+ SetLocal* makeTeeLocal(Index index, Expression* value) {
+ auto* ret = allocator.alloc<SetLocal>();
+ ret->index = index;
+ ret->value = value;
ret->type = value->type;
return ret;
}
@@ -164,10 +171,11 @@ public:
ret->type = type;
return ret;
}
- Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value) {
+ Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value, WasmType type) {
auto* ret = allocator.alloc<Store>();
- ret->bytes = bytes; ret->offset = offset; ret->align = align; ret->ptr = ptr; ret->value = value;
- ret->type = value->type;
+ ret->bytes = bytes; ret->offset = offset; ret->align = align; ret->ptr = ptr; ret->value = value; ret->valueType = type;
+ ret->finalize();
+ assert(isConcreteWasmType(ret->value->type) ? ret->value->type == type : true);
return ret;
}
Const* makeConst(Literal value) {
@@ -205,12 +213,22 @@ public:
ret->op = op;
ret->nameOperand = nameOperand;
ret->operands.set(operands);
+ ret->finalize();
return ret;
}
Unreachable* makeUnreachable() {
return allocator.alloc<Unreachable>();
}
+ // Additional helpers
+
+ Drop* makeDrop(Expression *value) {
+ auto* ret = allocator.alloc<Drop>();
+ ret->value = value;
+ ret->finalize();
+ return ret;
+ }
+
// Additional utility functions for building on top of nodes
static Index addParam(Function* func, Name name, WasmType type) {
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index 28a3e8e6d..9d1b5e522 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -435,6 +435,12 @@ public:
NOTE_EVAL1(condition.value);
return condition.value.geti32() ? ifTrue : ifFalse; // ;-)
}
+ Flow visitDrop(Drop *curr) {
+ NOTE_ENTER("Drop");
+ Flow value = visit(curr->value);
+ if (value.breaking()) return value;
+ return Flow();
+ }
Flow visitReturn(Return *curr) {
NOTE_ENTER("Return");
Flow flow;
@@ -693,9 +699,9 @@ public:
if (flow.breaking()) return flow;
NOTE_EVAL1(index);
NOTE_EVAL1(flow.value);
- assert(flow.value.type == curr->type);
+ assert(curr->isTee() ? flow.value.type == curr->type : true);
scope.locals[index] = flow.value;
- return flow;
+ return curr->isTee() ? flow : Flow();
}
Flow visitGetGlobal(GetGlobal *curr) {
@@ -730,7 +736,7 @@ public:
Flow value = visit(curr->value);
if (value.breaking()) return value;
instance.externalInterface->store(curr, instance.getFinalAddress(curr, ptr.value), value.value);
- return value;
+ return Flow();
}
Flow visitHost(Host *curr) {
diff --git a/src/wasm-js.cpp b/src/wasm-js.cpp
index 83956e47e..66dcb4864 100644
--- a/src/wasm-js.cpp
+++ b/src/wasm-js.cpp
@@ -411,11 +411,11 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
Module["info"].parent["HEAPU8"][addr + i] = HEAPU8[i];
}
HEAP32[0] = save0; HEAP32[1] = save1;
- }, (uint32_t)addr, store_->bytes, isWasmTypeFloat(store_->type), isWasmTypeFloat(store_->type) ? value.getFloat() : (double)value.getInteger());
+ }, (uint32_t)addr, store_->bytes, isWasmTypeFloat(store_->valueType), isWasmTypeFloat(store_->valueType) ? value.getFloat() : (double)value.getInteger());
return;
}
// nicely aligned
- if (!isWasmTypeFloat(store_->type)) {
+ if (!isWasmTypeFloat(store_->valueType)) {
if (store_->bytes == 1) {
EM_ASM_INT({ Module['info'].parent['HEAP8'][$0] = $1 }, addr, value.geti32());
} else if (store_->bytes == 2) {
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index 746d96cf2..685f26d82 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -568,7 +568,6 @@ private:
if (str[1] == '6' && str[2] == '4' && (prefix || str[3] == 0)) return f64;
}
if (allowError) return none;
- throw ParseException("unknown type");
abort();
}
@@ -621,7 +620,7 @@ public:
if (op[3] == '_') return makeBinary(s, op[4] == 'u' ? BINARY_INT(DivU) : BINARY_INT(DivS), type);
if (op[3] == 0) return makeBinary(s, BINARY_FLOAT(Div), type);
}
- if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type);
+ if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type);
abort_on(op);
}
case 'e': {
@@ -739,6 +738,10 @@ public:
} else if (str[1] == 'u') return makeHost(s, HostOp::CurrentMemory);
abort_on(str);
}
+ case 'd': {
+ if (str[1] == 'r') return makeDrop(s);
+ abort_on(str);
+ }
case 'e': {
if (str[1] == 'l') return makeThenOrElse(s);
abort_on(str);
@@ -785,6 +788,7 @@ public:
}
case 't': {
if (str[1] == 'h') return makeThenOrElse(s);
+ if (str[1] == 'e' && str[2] == 'e') return makeTeeLocal(s);
abort_on(str);
}
case 'u': {
@@ -878,6 +882,13 @@ private:
return ret;
}
+ Expression* makeDrop(Element& s) {
+ auto ret = allocator.alloc<Drop>();
+ ret->value = parseExpression(s[1]);
+ ret->finalize();
+ return ret;
+ }
+
Expression* makeHost(Element& s, HostOp op) {
auto ret = allocator.alloc<Host>();
ret->op = op;
@@ -910,11 +921,18 @@ private:
return ret;
}
+ Expression* makeTeeLocal(Element& s) {
+ auto ret = allocator.alloc<SetLocal>();
+ ret->index = getLocalIndex(*s[1]);
+ ret->value = parseExpression(s[2]);
+ ret->setTee(true);
+ return ret;
+ }
Expression* makeSetLocal(Element& s) {
auto ret = allocator.alloc<SetLocal>();
ret->index = getLocalIndex(*s[1]);
ret->value = parseExpression(s[2]);
- ret->type = currFunction->getLocalType(ret->index);
+ ret->setTee(false);
return ret;
}
@@ -1061,7 +1079,7 @@ private:
Expression* makeStore(Element& s, WasmType type) {
const char *extra = strchr(s[0]->c_str(), '.') + 6; // after "type.store"
auto ret = allocator.alloc<Store>();
- ret->type = type;
+ ret->valueType = type;
ret->bytes = getWasmTypeSize(type);
if (extra[0] == '8') {
ret->bytes = 1;
@@ -1092,6 +1110,7 @@ private:
}
ret->ptr = parseExpression(s[i]);
ret->value = parseExpression(s[i+1]);
+ ret->finalize();
return ret;
}
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index b50ca0fb2..d6484abdb 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -53,6 +53,7 @@ struct Visitor {
ReturnType visitUnary(Unary *curr) {}
ReturnType visitBinary(Binary *curr) {}
ReturnType visitSelect(Select *curr) {}
+ ReturnType visitDrop(Drop *curr) {}
ReturnType visitReturn(Return *curr) {}
ReturnType visitHost(Host *curr) {}
ReturnType visitNop(Nop *curr) {}
@@ -93,6 +94,7 @@ struct Visitor {
case Expression::Id::UnaryId: DELEGATE(Unary);
case Expression::Id::BinaryId: DELEGATE(Binary);
case Expression::Id::SelectId: DELEGATE(Select);
+ case Expression::Id::DropId: DELEGATE(Drop);
case Expression::Id::ReturnId: DELEGATE(Return);
case Expression::Id::HostId: DELEGATE(Host);
case Expression::Id::NopId: DELEGATE(Nop);
@@ -132,6 +134,7 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> {
ReturnType visitUnary(Unary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitBinary(Binary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitSelect(Select *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitDrop(Drop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitReturn(Return *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitHost(Host *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitNop(Nop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
@@ -264,14 +267,15 @@ struct Walker : public VisitorType {
static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitCallIndirect((*currp)->cast<CallIndirect>()); }
static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitGetLocal((*currp)->cast<GetLocal>()); }
static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitSetLocal((*currp)->cast<SetLocal>()); }
- static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); }
- static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); }
+ static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); }
+ static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); }
static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); }
static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); }
static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); }
static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); }
static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); }
static void doVisitSelect(SubType* self, Expression** currp) { self->visitSelect((*currp)->cast<Select>()); }
+ static void doVisitDrop(SubType* self, Expression** currp) { self->visitDrop((*currp)->cast<Drop>()); }
static void doVisitReturn(SubType* self, Expression** currp) { self->visitReturn((*currp)->cast<Return>()); }
static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast<Host>()); }
static void doVisitNop(SubType* self, Expression** currp) { self->visitNop((*currp)->cast<Nop>()); }
@@ -411,6 +415,11 @@ struct PostWalker : public Walker<SubType, VisitorType> {
self->pushTask(SubType::scan, &curr->cast<Select>()->ifTrue);
break;
}
+ case Expression::Id::DropId: {
+ self->pushTask(SubType::doVisitDrop, currp);
+ self->pushTask(SubType::scan, &curr->cast<Drop>()->value);
+ break;
+ }
case Expression::Id::ReturnId: {
self->pushTask(SubType::doVisitReturn, currp);
self->maybePushTask(SubType::scan, &curr->cast<Return>()->value);
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 3d9a0e1c3..b4818e253 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -30,7 +30,16 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>>
bool valid = true;
bool validateWebConstraints = false;
- std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type
+ struct BreakInfo {
+ WasmType type;
+ Index arity;
+ BreakInfo() {}
+ BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {}
+ };
+
+ std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them
+ std::map<Expression*, BreakInfo> breakInfos;
+
WasmType returnType = unreachable; // type used in returns
public:
@@ -42,55 +51,107 @@ public:
// visitors
+ static void visitPreBlock(WasmValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Block>();
+ if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
+ }
+
void visitBlock(Block *curr) {
// if we are break'ed to, then the value must be right for us
if (curr->name.is()) {
- // none or unreachable means a poison value that we should ignore - if consumed, it will error
- if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) {
- shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value");
+ if (breakInfos.count(curr) > 0) {
+ auto& info = breakInfos[curr];
+ // none or unreachable means a poison value that we should ignore - if consumed, it will error
+ if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) {
+ shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value");
+ }
+ shouldBeTrue(info.arity != Index(-1), curr, "break arities must match");
+ if (curr->list.size() > 0) {
+ auto last = curr->list.back()->type;
+ if (isConcreteWasmType(last) && info.type != unreachable) {
+ shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value");
+ }
+ if (last == none) {
+ shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type");
+ }
+ }
+ }
+ breakTargets[curr->name].pop_back();
+ }
+ if (curr->list.size() > 1) {
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) {
+ std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n";
+ }
}
- breakTypes.erase(curr->name);
}
}
- void visitIf(If *curr) {
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+
+ static void visitPreLoop(WasmValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Loop>();
+ if (curr->in.is()) self->breakTargets[curr->in].push_back(curr);
+ if (curr->out.is()) self->breakTargets[curr->out].push_back(curr);
}
+
void visitLoop(Loop *curr) {
if (curr->in.is()) {
- breakTypes.erase(curr->in);
+ breakTargets[curr->in].pop_back();
}
if (curr->out.is()) {
- breakTypes.erase(curr->out);
+ breakTargets[curr->out].pop_back();
}
}
- void noteBreak(Name name, Expression* value) {
+
+ void visitIf(If *curr) {
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+ }
+
+ // override scan to add a pre and a post check task to all nodes
+ static void scan(WasmValidator* self, Expression** currp) {
+ PostWalker<WasmValidator, Visitor<WasmValidator>>::scan(self, currp);
+
+ auto* curr = *currp;
+ if (curr->is<Block>()) self->pushTask(visitPreBlock, currp);
+ if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp);
+ }
+
+ void noteBreak(Name name, Expression* value, Expression* curr) {
WasmType valueType = none;
+ Index arity = 0;
if (value) {
valueType = value->type;
+ shouldBeUnequal(valueType, none, curr, "breaks must have a valid value");
+ arity = 1;
}
- if (breakTypes.count(name) == 0) {
- breakTypes[name] = valueType;
+ if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return;
+ auto* target = breakTargets[name].back();
+ if (breakInfos.count(target) == 0) {
+ breakInfos[target] = BreakInfo(valueType, arity);
} else {
- if (breakTypes[name] == unreachable) {
- breakTypes[name] = valueType;
+ auto& info = breakInfos[target];
+ if (info.type == unreachable) {
+ info.type = valueType;
} else if (valueType != unreachable) {
- if (valueType != breakTypes[name]) {
- breakTypes[name] = none; // a poison value that must not be consumed
+ if (valueType != info.type) {
+ info.type = none; // a poison value that must not be consumed
}
}
+ if (arity != info.arity) {
+ info.arity = Index(-1); // a poison value
+ }
}
}
void visitBreak(Break *curr) {
- noteBreak(curr->name, curr->value);
+ noteBreak(curr->name, curr->value, curr);
if (curr->condition) {
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32");
}
}
void visitSwitch(Switch *curr) {
for (auto& target : curr->targets) {
- noteBreak(target, curr->value);
+ noteBreak(target, curr->value, curr);
}
- noteBreak(curr->default_, curr->value);
+ noteBreak(curr->default_, curr->value, curr);
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
}
void visitCall(Call *curr) {
@@ -128,7 +189,9 @@ public:
void visitSetLocal(SetLocal *curr) {
shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough");
if (curr->value->type != unreachable) {
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ if (curr->type != none) { // tee is ok anyhow
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ }
shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function");
}
}
@@ -139,7 +202,8 @@ public:
void visitStore(Store *curr) {
validateAlignment(curr->align);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match");
+ shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
+ // TODO: enable a check that replaces this, for type being none shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match");
}
void visitBinary(Binary *curr) {
if (curr->left->type != unreachable && curr->right->type != unreachable) {
@@ -268,13 +332,11 @@ public:
void visitFunction(Function *curr) {
// if function has no result, it is ignored
// if body is unreachable, it might be e.g. a return
- if (curr->result != none) {
- if (curr->body->type != unreachable) {
- shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
- }
- if (returnType != unreachable) {
- shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns");
- }
+ if (curr->body->type != unreachable) {
+ shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
+ }
+ if (returnType != unreachable) {
+ shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns");
}
returnType = unreachable;
}
@@ -311,12 +373,6 @@ public:
void doWalkFunction(Function* func) {
PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func);
- if (!shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid")) {
- for (auto& target : breakTypes) {
- std::cerr << " - " << target.first << '\n';
- }
- breakTypes.clear();
- }
}
private:
diff --git a/src/wasm.h b/src/wasm.h
index 68558033d..ee3e12910 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -802,7 +802,7 @@ enum UnaryOp {
ConvertSInt32ToFloat32, ConvertSInt32ToFloat64, ConvertUInt32ToFloat32, ConvertUInt32ToFloat64, ConvertSInt64ToFloat32, ConvertSInt64ToFloat64, ConvertUInt64ToFloat32, ConvertUInt64ToFloat64, // int to float
PromoteFloat32, // f32 to f64
DemoteFloat64, // f64 to f32
- ReinterpretInt32, ReinterpretInt64 // reinterpret bits to float
+ ReinterpretInt32, ReinterpretInt64, // reinterpret bits to float
};
enum BinaryOp {
@@ -877,6 +877,7 @@ public:
UnaryId,
BinaryId,
SelectId,
+ DropId,
ReturnId,
HostId,
NopId,
@@ -929,6 +930,7 @@ inline const char *getExpressionName(Expression *curr) {
case Expression::Id::UnaryId: return "unary";
case Expression::Id::BinaryId: return "binary";
case Expression::Id::SelectId: return "select";
+ case Expression::Id::DropId: return "drop";
case Expression::Id::ReturnId: return "return";
case Expression::Id::HostId: return "host";
case Expression::Id::NopId: return "nop";
@@ -1108,8 +1110,13 @@ public:
Index index;
Expression *value;
- void finalize() {
- type = value->type;
+ bool isTee() {
+ return type != none;
+ }
+
+ void setTee(bool is) {
+ if (is) type = value->type;
+ else type = none;
}
};
@@ -1150,16 +1157,17 @@ public:
class Store : public SpecificExpression<Expression::StoreId> {
public:
- Store() {}
- Store(MixedArena& allocator) {}
+ Store() : valueType(none) {}
+ Store(MixedArena& allocator) : Store() {}
uint8_t bytes;
Address offset;
Address align;
Expression *ptr, *value;
+ WasmType valueType; // the store never returns a value
void finalize() {
- type = value->type;
+ assert(valueType != none); // must be set
}
};
@@ -1312,6 +1320,14 @@ public:
}
};
+class Drop : public SpecificExpression<Expression::DropId> {
+public:
+ Drop() {}
+ Drop(MixedArena& allocator) {}
+
+ Expression *value;
+};
+
class Return : public SpecificExpression<Expression::ReturnId> {
public:
Return() : value(nullptr) {