diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/asm2wasm.h | 28 | ||||
-rw-r--r-- | src/ast_utils.h | 50 | ||||
-rw-r--r-- | src/binaryen-c.cpp | 43 | ||||
-rw-r--r-- | src/binaryen-c.h | 9 | ||||
-rw-r--r-- | src/passes/CoalesceLocals.cpp | 12 | ||||
-rw-r--r-- | src/passes/DeadCodeElimination.cpp | 21 | ||||
-rw-r--r-- | src/passes/MergeBlocks.cpp | 19 | ||||
-rw-r--r-- | src/passes/Print.cpp | 16 | ||||
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 42 | ||||
-rw-r--r-- | src/passes/Vacuum.cpp | 8 | ||||
-rw-r--r-- | src/s2wasm.h | 5 | ||||
-rw-r--r-- | src/shell-interface.h | 2 | ||||
-rw-r--r-- | src/wasm-binary.h | 45 | ||||
-rw-r--r-- | src/wasm-builder.h | 24 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 12 | ||||
-rw-r--r-- | src/wasm-js.cpp | 4 | ||||
-rw-r--r-- | src/wasm-s-parser.h | 27 | ||||
-rw-r--r-- | src/wasm-traversal.h | 13 | ||||
-rw-r--r-- | src/wasm-validator.h | 124 | ||||
-rw-r--r-- | src/wasm.h | 28 |
20 files changed, 421 insertions, 111 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h index ce3c0749d..857145937 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -793,6 +793,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { }; PassRunner passRunner(&wasm); passRunner.add<FinalizeCalls>(this); + passRunner.add<AutoDrop>(); passRunner.run(); // apply memory growth, if relevant @@ -802,7 +803,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { wasm.addFunction(builder.makeFunction( GROW_WASM_MEMORY, { { NEW_SIZE, i32 } }, - none, + i32, {}, builder.makeHost( GrowMemory, @@ -1009,7 +1010,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { auto ret = allocator.alloc<SetLocal>(); ret->index = function->getLocalIndex(ast[2][1]->getIString()); ret->value = process(ast[3]); - ret->type = ret->value->type; + ret->setTee(false); + ret->finalize(); return ret; } // global var, do a store to memory @@ -1021,7 +1023,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->align = ret->bytes; ret->ptr = builder.makeConst(Literal(int32_t(global.address))); ret->value = process(ast[3]); - ret->type = global.type; + ret->valueType = global.type; + ret->finalize(); return ret; } else if (ast[2][0] == SUB) { Ref target = ast[2]; @@ -1035,10 +1038,11 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->align = view.bytes; ret->ptr = processUnshifted(target[2], view.bytes); ret->value = process(ast[3]); - ret->type = asmToWasmType(view.type); - if (ret->type != ret->value->type) { + ret->valueType = asmToWasmType(view.type); + ret->finalize(); + if (ret->valueType != ret->value->type) { // in asm.js we have some implicit coercions that we must do explicitly here - if (ret->type == f32 && ret->value->type == f64) { + if (ret->valueType == f32 && ret->value->type == f64) { auto conv = allocator.alloc<Unary>(); conv->op = DemoteFloat64; conv->value = ret->value; @@ -1273,11 +1277,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { } abort_on("bad unary", ast); } else if (what == IF) { - auto ret = allocator.alloc<If>(); - ret->condition = process(ast[1]); - ret->ifTrue = process(ast[2]); - ret->ifFalse = !!ast[3] ? process(ast[3]) : nullptr; - return ret; + return builder.makeIf(process(ast[1]), process(ast[2]), !!ast[3] ? process(ast[3]) : nullptr); } else if (what == CALL) { if (ast[1][0] == NAME) { IString name = ast[1][1]->getIString(); @@ -1330,9 +1330,10 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { // No wasm support, so use a temp local ensureI32Temp(); auto set = allocator.alloc<SetLocal>(); + set->setTee(false); set->index = function->getLocalIndex(I32_TEMP); set->value = value; - set->type = i32; + set->finalize(); auto get = [&]() { auto ret = allocator.alloc<GetLocal>(); ret->index = function->getLocalIndex(I32_TEMP); @@ -1526,6 +1527,9 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { if (breakSeeker.found == 0) { auto block = allocator.alloc<Block>(); block->list.push_back(child); + if (isConcreteWasmType(child->type)) { + block->list.push_back(builder.makeNop()); // ensure a nop at the end, so the block has guaranteed none type and no values fall through + } block->name = stop; block->finalize(); return block; diff --git a/src/ast_utils.h b/src/ast_utils.h index 3e45d0e33..17f5490a9 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -21,6 +21,7 @@ #include "wasm.h" #include "wasm-traversal.h" #include "wasm-builder.h" +#include "pass.h" namespace wasm { @@ -277,7 +278,11 @@ struct ExpressionManipulator { return builder.makeGetLocal(curr->index, curr->type); } Expression* visitSetLocal(SetLocal *curr) { - return builder.makeSetLocal(curr->index, copy(curr->value)); + if (curr->isTee()) { + return builder.makeTeeLocal(curr->index, copy(curr->value)); + } else { + return builder.makeSetLocal(curr->index, copy(curr->value)); + } } Expression* visitGetGlobal(GetGlobal *curr) { return builder.makeGetGlobal(curr->index, curr->type); @@ -289,7 +294,7 @@ struct ExpressionManipulator { return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type); } Expression* visitStore(Store *curr) { - return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value)); + return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType); } Expression* visitConst(Const *curr) { return builder.makeConst(curr->value); @@ -303,6 +308,9 @@ struct ExpressionManipulator { Expression* visitSelect(Select *curr) { return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); } + Expression* visitDrop(Drop *curr) { + return builder.makeDrop(copy(curr->value)); + } Expression* visitReturn(Return *curr) { return builder.makeReturn(copy(curr->value)); } @@ -340,7 +348,7 @@ struct ExpressionAnalyzer { for (int i = int(stack.size()) - 2; i >= 0; i--) { auto* curr = stack[i]; auto* above = stack[i + 1]; - // only if and block can drop values + // only if and block can drop values (pre-drop expression was added) FIXME if (curr->is<Block>()) { auto* block = curr->cast<Block>(); for (size_t j = 0; j < block->list.size() - 1; j++) { @@ -355,6 +363,7 @@ struct ExpressionAnalyzer { assert(above == iff->ifTrue || above == iff->ifFalse); // continue down } else { + if (curr->is<Drop>()) return false; return true; // all other node types use the result } } @@ -481,6 +490,7 @@ struct ExpressionAnalyzer { } case Expression::Id::SetLocalId: { CHECK(SetLocal, index); + CHECK(SetLocal, type); // for tee/set PUSH(SetLocal, value); break; } @@ -505,6 +515,7 @@ struct ExpressionAnalyzer { CHECK(Store, bytes); CHECK(Store, offset); CHECK(Store, align); + CHECK(Store, valueType); PUSH(Store, ptr); PUSH(Store, value); break; @@ -530,6 +541,10 @@ struct ExpressionAnalyzer { PUSH(Select, condition); break; } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } case Expression::Id::ReturnId: { PUSH(Return, value); break; @@ -716,6 +731,7 @@ struct ExpressionAnalyzer { HASH(Store, bytes); HASH(Store, offset); HASH(Store, align); + HASH(Store, valueType); PUSH(Store, ptr); PUSH(Store, value); break; @@ -742,6 +758,10 @@ struct ExpressionAnalyzer { PUSH(Select, condition); break; } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } case Expression::Id::ReturnId: { PUSH(Return, value); break; @@ -770,6 +790,30 @@ struct ExpressionAnalyzer { } }; +// Adds drop() operations where necessary. This lets you not worry about adding drop when +// generating code. +struct AutoDrop : public WalkerPass<PostWalker<AutoDrop, Visitor<AutoDrop>>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new AutoDrop; } + + void visitBlock(Block* curr) { + if (curr->list.size() <= 1) return; + for (Index i = 0; i < curr->list.size() - 1; i++) { + auto* child = curr->list[i]; + if (isConcreteWasmType(child->type)) { + curr->list[i] = Builder(*getModule()).makeDrop(child); + } + } + } + + void visitFunction(Function* curr) { + if (curr->result == none && isConcreteWasmType(curr->body->type)) { + curr->body = Builder(*getModule()).makeDrop(curr->body); + } + } +}; + } // namespace wasm #endif // wasm_ast_utils_h diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 83c626eb1..71c6cad80 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -488,6 +488,21 @@ BinaryenExpressionRef BinaryenSetLocal(BinaryenModuleRef module, BinaryenIndex i ret->index = index; ret->value = (Expression*)value; + ret->setTee(false); + ret->finalize(); + return static_cast<Expression*>(ret); +} +BinaryenExpressionRef BinaryenTeeLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value) { + auto* ret = ((Module*)module)->allocator.alloc<SetLocal>(); + + if (tracing) { + auto id = noteExpression(ret); + std::cout << " expressions[" << id << "] = BinaryenTeeLocal(the_module, " << index << ", expressions[" << expressions[value] << "]);\n"; + } + + ret->index = index; + ret->value = (Expression*)value; + ret->setTee(true); ret->finalize(); return static_cast<Expression*>(ret); } @@ -508,12 +523,12 @@ BinaryenExpressionRef BinaryenLoad(BinaryenModuleRef module, uint32_t bytes, int ret->finalize(); return static_cast<Expression*>(ret); } -BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value) { +BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value, BinaryenType type) { auto* ret = ((Module*)module)->allocator.alloc<Store>(); if (tracing) { auto id = noteExpression(ret); - std::cout << " expressions[" << id << "] = BinaryenStore(the_module, " << bytes << ", " << offset << ", " << align << ", expressions[" << expressions[ptr] << "], expressions[" << expressions[value] << "]);\n"; + std::cout << " expressions[" << id << "] = BinaryenStore(the_module, " << bytes << ", " << offset << ", " << align << ", expressions[" << expressions[ptr] << "], expressions[" << expressions[value] << "], " << type << ");\n"; } ret->bytes = bytes; @@ -521,6 +536,7 @@ BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, ui ret->align = align ? align : bytes; ret->ptr = (Expression*)ptr; ret->value = (Expression*)value; + ret->valueType = WasmType(type); ret->finalize(); return static_cast<Expression*>(ret); } @@ -584,6 +600,18 @@ BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressio ret->finalize(); return static_cast<Expression*>(ret); } +BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value) { + auto* ret = ((Module*)module)->allocator.alloc<Drop>(); + + if (tracing) { + auto id = noteExpression(ret); + std::cout << " expressions[" << id << "] = BinaryenDrop(the_module, expressions[" << expressions[value] << "]);\n"; + } + + ret->value = (Expression*)value; + ret->finalize(); + return static_cast<Expression*>(ret); +} BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value) { auto* ret = Builder(*((Module*)module)).makeReturn((Expression*)value); @@ -829,6 +857,17 @@ void BinaryenModuleOptimize(BinaryenModuleRef module) { passRunner.run(); } +void BinaryenModuleAutoDrop(BinaryenModuleRef module) { + if (tracing) { + std::cout << " BinaryenModuleAutoDrop(the_module);\n"; + } + + Module* wasm = (Module*)module; + PassRunner passRunner(wasm); + passRunner.add<AutoDrop>(); + passRunner.run(); +} + size_t BinaryenModuleWrite(BinaryenModuleRef module, char* output, size_t outputSize) { if (tracing) { std::cout << " // BinaryenModuleWrite\n"; diff --git a/src/binaryen-c.h b/src/binaryen-c.h index e2086072b..7ca5cd6c7 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -294,14 +294,16 @@ BinaryenExpressionRef BinaryenCallIndirect(BinaryenModuleRef module, BinaryenExp // for more details. BinaryenExpressionRef BinaryenGetLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenType type); BinaryenExpressionRef BinaryenSetLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value); +BinaryenExpressionRef BinaryenTeeLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value); // Load: align can be 0, in which case it will be the natural alignment (equal to bytes) BinaryenExpressionRef BinaryenLoad(BinaryenModuleRef module, uint32_t bytes, int8_t signed_, uint32_t offset, uint32_t align, BinaryenType type, BinaryenExpressionRef ptr); // Store: align can be 0, in which case it will be the natural alignment (equal to bytes) -BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value); +BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value, BinaryenType type); BinaryenExpressionRef BinaryenConst(BinaryenModuleRef module, struct BinaryenLiteral value); BinaryenExpressionRef BinaryenUnary(BinaryenModuleRef module, BinaryenOp op, BinaryenExpressionRef value); BinaryenExpressionRef BinaryenBinary(BinaryenModuleRef module, BinaryenOp op, BinaryenExpressionRef left, BinaryenExpressionRef right); BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, BinaryenExpressionRef ifFalse); +BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value); // Return: value can be NULL BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value); // Host: name may be NULL @@ -366,6 +368,11 @@ int BinaryenModuleValidate(BinaryenModuleRef module); // Run the standard optimization passes on the module. void BinaryenModuleOptimize(BinaryenModuleRef module); +// Auto-generate drop() operations where needed. This lets you generate code without +// worrying about where they are needed. (It is more efficient to do it yourself, +// but simpler to use autodrop). +void BinaryenModuleAutoDrop(BinaryenModuleRef module); + // Serialize a module into binary form. // @return how many bytes were written. This will be less than or equal to bufferSize size_t BinaryenModuleWrite(BinaryenModuleRef module, char* output, size_t outputSize); diff --git a/src/passes/CoalesceLocals.cpp b/src/passes/CoalesceLocals.cpp index 2c707b614..2063a20db 100644 --- a/src/passes/CoalesceLocals.cpp +++ b/src/passes/CoalesceLocals.cpp @@ -485,9 +485,19 @@ void CoalesceLocals::applyIndices(std::vector<Index>& indices, Expression* root) // in addition, we can optimize out redundant copies and ineffective sets GetLocal* get; if ((get = set->value->dynCast<GetLocal>()) && get->index == set->index) { - *action.origin = get; // further optimizations may get rid of the get, if this is in a place where the output does not matter + if (set->isTee()) { + *action.origin = get; + } else { + ExpressionManipulator::nop(set); + } } else if (!action.effective) { *action.origin = set->value; // value may have no side effects, further optimizations can eliminate it + if (!set->isTee()) { + // we need to drop it + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(set); + drop->value = *action.origin; + *action.origin = drop; + } } } } diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index de1191131..aba442e71 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -31,6 +31,7 @@ #include <wasm.h> #include <pass.h> #include <ast_utils.h> +#include <wasm-builder.h> namespace wasm { @@ -191,6 +192,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V case Expression::Id::UnaryId: DELEGATE(Unary); case Expression::Id::BinaryId: DELEGATE(Binary); case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::DropId: DELEGATE(Drop); case Expression::Id::ReturnId: DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); case Expression::Id::NopId: DELEGATE(Nop); @@ -226,6 +228,11 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V // other things + Expression* drop(Expression* toDrop) { + if (toDrop->is<Unreachable>()) return toDrop; + return Builder(*getModule()).makeDrop(toDrop); + } + template<typename T> void handleCall(T* curr, Expression* initial) { for (Index i = 0; i < curr->operands.size(); i++) { @@ -236,11 +243,11 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V block->list.resize(newSize); Index j = 0; if (initial) { - block->list[j] = initial; + block->list[j] = drop(initial); j++; } for (; j < newSize; j++) { - block->list[j] = curr->operands[j - (initial ? 1 : 0)]; + block->list[j] = drop(curr->operands[j - (initial ? 1 : 0)]); } block->finalize(); replaceCurrent(block); @@ -288,7 +295,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->value)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->ptr; + block->list[0] = drop(curr->ptr); block->list[1] = curr->value; block->finalize(); replaceCurrent(block); @@ -309,7 +316,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->right)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->left; + block->list[0] = drop(curr->left); block->list[1] = curr->right; block->finalize(); replaceCurrent(block); @@ -324,7 +331,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->ifFalse)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->ifTrue; + block->list[0] = drop(curr->ifTrue); block->list[1] = curr->ifFalse; block->finalize(); replaceCurrent(block); @@ -333,8 +340,8 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->condition)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(3); - block->list[0] = curr->ifTrue; - block->list[1] = curr->ifFalse; + block->list[0] = drop(curr->ifTrue); + block->list[1] = drop(curr->ifFalse); block->list[2] = curr->condition; block->finalize(); replaceCurrent(block); diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index 686bb5d75..4798ad28d 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -74,10 +74,27 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc void visitBlock(Block *curr) { bool more = true; + bool changed = false; while (more) { more = false; for (size_t i = 0; i < curr->list.size(); i++) { Block* child = curr->list[i]->dynCast<Block>(); + if (!child) { + // TODO: if we have a child that is (drop (block ..)) then we can move the drop into the block, allowing more merging, + // but we must also drop values from brs + /* + auto* drop = curr->list[i]->dynCast<Drop>(); + if (drop) { + child = drop->value->dynCast<Block>(); + if (child) { + // reuse the drop + drop->value = child->list.back(); + child->list.back() = drop; + curr->list[i] = child; + } + } + */ + } if (!child) continue; if (child->name.is()) continue; // named blocks can have breaks to them (and certainly do, if we ran RemoveUnusedNames and RemoveUnusedBrs) ExpressionList merged(getModule()->allocator); @@ -92,9 +109,11 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc } curr->list = merged; more = true; + changed = true; break; } } + if (changed) curr->finalize(); } Block* optimize(Expression* curr, Expression*& child, Block* outer = nullptr, Expression** dependency1 = nullptr, Expression** dependency2 = nullptr) { diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 7486e07f1..ac349abc7 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -240,7 +240,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> { printOpening(o, "get_local ") << printableLocal(curr->index) << ')'; } void visitSetLocal(SetLocal *curr) { - printOpening(o, "set_local ") << printableLocal(curr->index); + if (curr->isTee()) { + printOpening(o, "tee_local "); + } else { + printOpening(o, "set_local "); + } + o << printableLocal(curr->index); incIndent(); printFullLine(curr->value); decIndent(); @@ -282,7 +287,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void visitStore(Store *curr) { o << '('; - prepareColor(o) << printWasmType(curr->type) << ".store"; + prepareColor(o) << printWasmType(curr->valueType) << ".store"; if (curr->bytes < 4 || (curr->type == i64 && curr->bytes < 8)) { if (curr->bytes == 1) { o << '8'; @@ -467,6 +472,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> { printFullLine(curr->condition); decIndent(); } + void visitDrop(Drop *curr) { + o << '('; + prepareColor(o) << "drop"; + incIndent(); + printFullLine(curr->value); + decIndent(); + } void visitReturn(Return *curr) { printOpening(o, "return"); if (!curr->value || curr->value->is<Nop>()) { diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index d9786e62c..5315edac4 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -55,7 +55,13 @@ struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemo void visitSetLocal(SetLocal *curr) { if ((*numGetLocals)[curr->index] == 0) { - replaceCurrent(curr->value); + auto* value = curr->value; + if (curr->isTee()) { + replaceCurrent(value); + } else { + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(curr); + drop->value = value; + } } } }; @@ -180,7 +186,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, auto found = sinkables.find(curr->index); if (found != sinkables.end()) { // sink it, and nop the origin - replaceCurrent(*found->second.item); + auto* set = (*found->second.item)->cast<SetLocal>(); + replaceCurrent(set); + assert(!set->isTee()); + set->setTee(true); // reuse the getlocal that is dying *found->second.item = curr; ExpressionManipulator::nop(curr); @@ -189,6 +198,16 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, } } + void visitDrop(Drop* curr) { + // collapse drop-tee into set, which can occur if a get was sunk into a tee + auto* set = curr->value->dynCast<SetLocal>(); + if (set) { + assert(set->isTee()); + set->setTee(false); + replaceCurrent(set); + } + } + void checkInvalidations(EffectAnalyzer& effects) { // TODO: this is O(bad) std::vector<Index> invalidated; @@ -225,7 +244,11 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, // store is dead, leave just the value auto found = self->sinkables.find(set->index); if (found != self->sinkables.end()) { - *found->second.item = (*found->second.item)->cast<SetLocal>()->value; + auto* previous = (*found->second.item)->cast<SetLocal>(); + assert(!previous->isTee()); + auto* previousValue = previous->value; + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(previous); + drop->value = previousValue; self->sinkables.erase(found); self->anotherCycle = true; } @@ -236,15 +259,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, self->checkInvalidations(effects); } - if (set) { - // we may be a replacement for the current node, update the stack - self->expressionStack.pop_back(); - self->expressionStack.push_back(set); - if (!ExpressionAnalyzer::isResultUsed(self->expressionStack, self->getFunction())) { - Index index = set->index; - assert(self->sinkables.count(index) == 0); - self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); - } + if (set && !set->isTee()) { + Index index = set->index; + assert(self->sinkables.count(index) == 0); + self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); } self->expressionStack.pop_back(); diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index 0195427d7..1b0887181 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -41,6 +41,7 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { case Expression::Id::BlockId: return curr; // not always needed, but handled in visitBlock() case Expression::Id::IfId: return curr; // not always needed, but handled in visitIf() case Expression::Id::LoopId: return curr; // not always needed, but handled in visitLoop() + case Expression::Id::DropId: return curr; // not always needed, but handled in visitDrop() case Expression::Id::BreakId: case Expression::Id::SwitchId: @@ -198,6 +199,13 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { if (curr->body->is<Nop>()) ExpressionManipulator::nop(curr); } + void visitDrop(Drop* curr) { + // if the drop input has no side effects, it can be wiped out + if (!EffectAnalyzer(curr->value).hasSideEffects()) { + ExpressionManipulator::nop(curr); + } + } + static void visitPre(Vacuum* self, Expression** currp) { self->expressionStack.push_back(*currp); } diff --git a/src/s2wasm.h b/src/s2wasm.h index a23cbb549..5226ceda3 100644 --- a/src/s2wasm.h +++ b/src/s2wasm.h @@ -753,6 +753,7 @@ class S2WasmBuilder { set->index = func->getLocalIndex(assign); set->value = curr; set->type = curr->type; + set->setTee(false); addToBlock(set); } }; @@ -834,7 +835,7 @@ class S2WasmBuilder { auto makeStore = [&](WasmType type) { skipComma(); auto curr = allocator->alloc<Store>(); - curr->type = type; + curr->valueType = type; int32_t bytes = getInt() / CHAR_BIT; curr->bytes = bytes > 0 ? bytes : getWasmTypeSize(type); Name assign = getAssign(); @@ -849,6 +850,7 @@ class S2WasmBuilder { curr->align = 1U << getInt(attributes[0] + 8); } curr->value = inputs[1]; + curr->finalize(); setOutput(curr, assign); }; auto makeSelect = [&](WasmType type) { @@ -1146,6 +1148,7 @@ class S2WasmBuilder { Name assign = getAssign(); skipComma(); auto curr = allocator->alloc<SetLocal>(); + curr->setTee(true); curr->index = func->getLocalIndex(getAssign()); skipComma(); curr->value = getInput(); diff --git a/src/shell-interface.h b/src/shell-interface.h index f332307ad..2b0c491b9 100644 --- a/src/shell-interface.h +++ b/src/shell-interface.h @@ -167,7 +167,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { } void store(Store* store, Address addr, Literal value) override { - switch (store->type) { + switch (store->valueType) { case i32: { switch (store->bytes) { case 1: memory.set<int8_t>(addr, value.geti32()); break; diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 04bebddc5..791a0459b 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -412,6 +412,7 @@ enum ASTNodes { CallFunction = 0x16, CallIndirect = 0x17, CallImport = 0x18, + TeeLocal = 0x19, GetGlobal = 0x1a, SetGlobal = 0x1b, @@ -426,6 +427,7 @@ enum ASTNodes { TableSwitch = 0x08, Return = 0x09, Unreachable = 0x0a, + Drop = 0x0b, End = 0x0f }; @@ -939,9 +941,9 @@ public: o << int8_t(BinaryConsts::GetLocal) << U32LEB(mappedLocals[curr->index]); } void visitSetLocal(SetLocal *curr) { - if (debug) std::cerr << "zz node: SetLocal" << std::endl; + if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl; recurse(curr->value); - o << int8_t(BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]); + o << int8_t(curr->isTee() ? BinaryConsts::TeeLocal : BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]); } void visitGetGlobal(GetGlobal *curr) { if (debug) std::cerr << "zz node: GetGlobal " << (o.size() + 1) << std::endl; @@ -991,7 +993,7 @@ public: if (debug) std::cerr << "zz node: Store" << std::endl; recurse(curr->ptr); recurse(curr->value); - switch (curr->type) { + switch (curr->valueType) { case i32: { switch (curr->bytes) { case 1: o << int8_t(BinaryConsts::I32StoreMem8); break; @@ -1219,6 +1221,11 @@ public: if (debug) std::cerr << "zz node: Unreachable" << std::endl; o << int8_t(BinaryConsts::Unreachable); } + void visitDrop(Drop *curr) { + if (debug) std::cerr << "zz node: Drop" << std::endl; + recurse(curr->value); + o << int8_t(BinaryConsts::Drop); + } }; class WasmBinaryBuilder { @@ -1728,13 +1735,15 @@ public: case BinaryConsts::CallImport: visitCallImport((curr = allocator.alloc<CallImport>())->cast<CallImport>()); break; case BinaryConsts::CallIndirect: visitCallIndirect((curr = allocator.alloc<CallIndirect>())->cast<CallIndirect>()); break; case BinaryConsts::GetLocal: visitGetLocal((curr = allocator.alloc<GetLocal>())->cast<GetLocal>()); break; - case BinaryConsts::SetLocal: visitSetLocal((curr = allocator.alloc<SetLocal>())->cast<SetLocal>()); break; + case BinaryConsts::TeeLocal: + case BinaryConsts::SetLocal: visitSetLocal((curr = allocator.alloc<SetLocal>())->cast<SetLocal>(), code); break; case BinaryConsts::GetGlobal: visitGetGlobal((curr = allocator.alloc<GetGlobal>())->cast<GetGlobal>()); break; case BinaryConsts::SetGlobal: visitSetGlobal((curr = allocator.alloc<SetGlobal>())->cast<SetGlobal>()); break; case BinaryConsts::Select: visitSelect((curr = allocator.alloc<Select>())->cast<Select>()); break; case BinaryConsts::Return: visitReturn((curr = allocator.alloc<Return>())->cast<Return>()); break; case BinaryConsts::Nop: visitNop((curr = allocator.alloc<Nop>())->cast<Nop>()); break; case BinaryConsts::Unreachable: visitUnreachable((curr = allocator.alloc<Unreachable>())->cast<Unreachable>()); break; + case BinaryConsts::Drop: visitDrop((curr = allocator.alloc<Drop>())->cast<Drop>()); break; case BinaryConsts::End: case BinaryConsts::Else: curr = nullptr; break; default: { @@ -1922,12 +1931,13 @@ public: assert(curr->index < currFunction->getNumLocals()); curr->type = currFunction->getLocalType(curr->index); } - void visitSetLocal(SetLocal *curr) { - if (debug) std::cerr << "zz node: SetLocal" << std::endl; + void visitSetLocal(SetLocal *curr, uint8_t code) { + if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl; curr->index = getU32LEB(); assert(curr->index < currFunction->getNumLocals()); curr->value = popExpression(); curr->type = curr->value->type; + curr->setTee(code == BinaryConsts::TeeLocal); } void visitGetGlobal(GetGlobal *curr) { if (debug) std::cerr << "zz node: GetGlobal " << pos << std::endl; @@ -1976,21 +1986,22 @@ public: bool maybeVisitStore(Expression*& out, uint8_t code) { Store* curr; switch (code) { - case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->type = i32; break; - case BinaryConsts::I32StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->type = i32; break; - case BinaryConsts::I32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = i32; break; - case BinaryConsts::I64StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->type = i64; break; - case BinaryConsts::I64StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->type = i64; break; - case BinaryConsts::I64StoreMem32: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = i64; break; - case BinaryConsts::I64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->type = i64; break; - case BinaryConsts::F32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = f32; break; - case BinaryConsts::F64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->type = f64; break; + case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->valueType = i32; break; + case BinaryConsts::I32StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->valueType = i32; break; + case BinaryConsts::I32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = i32; break; + case BinaryConsts::I64StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->valueType = i64; break; + case BinaryConsts::I64StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->valueType = i64; break; + case BinaryConsts::I64StoreMem32: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = i64; break; + case BinaryConsts::I64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->valueType = i64; break; + case BinaryConsts::F32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = f32; break; + case BinaryConsts::F64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->valueType = f64; break; default: return false; } if (debug) std::cerr << "zz node: Store" << std::endl; readMemoryAccess(curr->align, curr->bytes, curr->offset); curr->value = popExpression(); curr->ptr = popExpression(); + curr->finalize(); out = curr; return true; } @@ -2175,6 +2186,10 @@ public: void visitUnreachable(Unreachable *curr) { if (debug) std::cerr << "zz node: Unreachable" << std::endl; } + void visitDrop(Drop *curr) { + if (debug) std::cerr << "zz node: Drop" << std::endl; + curr->value = popExpression(); + } }; } // namespace wasm diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 22e1f9a00..3cba351a2 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -142,6 +142,13 @@ public: auto* ret = allocator.alloc<SetLocal>(); ret->index = index; ret->value = value; + ret->type = none; + return ret; + } + SetLocal* makeTeeLocal(Index index, Expression* value) { + auto* ret = allocator.alloc<SetLocal>(); + ret->index = index; + ret->value = value; ret->type = value->type; return ret; } @@ -164,10 +171,11 @@ public: ret->type = type; return ret; } - Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value) { + Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value, WasmType type) { auto* ret = allocator.alloc<Store>(); - ret->bytes = bytes; ret->offset = offset; ret->align = align; ret->ptr = ptr; ret->value = value; - ret->type = value->type; + ret->bytes = bytes; ret->offset = offset; ret->align = align; ret->ptr = ptr; ret->value = value; ret->valueType = type; + ret->finalize(); + assert(isConcreteWasmType(ret->value->type) ? ret->value->type == type : true); return ret; } Const* makeConst(Literal value) { @@ -205,12 +213,22 @@ public: ret->op = op; ret->nameOperand = nameOperand; ret->operands.set(operands); + ret->finalize(); return ret; } Unreachable* makeUnreachable() { return allocator.alloc<Unreachable>(); } + // Additional helpers + + Drop* makeDrop(Expression *value) { + auto* ret = allocator.alloc<Drop>(); + ret->value = value; + ret->finalize(); + return ret; + } + // Additional utility functions for building on top of nodes static Index addParam(Function* func, Name name, WasmType type) { diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 28a3e8e6d..9d1b5e522 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -435,6 +435,12 @@ public: NOTE_EVAL1(condition.value); return condition.value.geti32() ? ifTrue : ifFalse; // ;-) } + Flow visitDrop(Drop *curr) { + NOTE_ENTER("Drop"); + Flow value = visit(curr->value); + if (value.breaking()) return value; + return Flow(); + } Flow visitReturn(Return *curr) { NOTE_ENTER("Return"); Flow flow; @@ -693,9 +699,9 @@ public: if (flow.breaking()) return flow; NOTE_EVAL1(index); NOTE_EVAL1(flow.value); - assert(flow.value.type == curr->type); + assert(curr->isTee() ? flow.value.type == curr->type : true); scope.locals[index] = flow.value; - return flow; + return curr->isTee() ? flow : Flow(); } Flow visitGetGlobal(GetGlobal *curr) { @@ -730,7 +736,7 @@ public: Flow value = visit(curr->value); if (value.breaking()) return value; instance.externalInterface->store(curr, instance.getFinalAddress(curr, ptr.value), value.value); - return value; + return Flow(); } Flow visitHost(Host *curr) { diff --git a/src/wasm-js.cpp b/src/wasm-js.cpp index 83956e47e..66dcb4864 100644 --- a/src/wasm-js.cpp +++ b/src/wasm-js.cpp @@ -411,11 +411,11 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { Module["info"].parent["HEAPU8"][addr + i] = HEAPU8[i]; } HEAP32[0] = save0; HEAP32[1] = save1; - }, (uint32_t)addr, store_->bytes, isWasmTypeFloat(store_->type), isWasmTypeFloat(store_->type) ? value.getFloat() : (double)value.getInteger()); + }, (uint32_t)addr, store_->bytes, isWasmTypeFloat(store_->valueType), isWasmTypeFloat(store_->valueType) ? value.getFloat() : (double)value.getInteger()); return; } // nicely aligned - if (!isWasmTypeFloat(store_->type)) { + if (!isWasmTypeFloat(store_->valueType)) { if (store_->bytes == 1) { EM_ASM_INT({ Module['info'].parent['HEAP8'][$0] = $1 }, addr, value.geti32()); } else if (store_->bytes == 2) { diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 746d96cf2..685f26d82 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -568,7 +568,6 @@ private: if (str[1] == '6' && str[2] == '4' && (prefix || str[3] == 0)) return f64; } if (allowError) return none; - throw ParseException("unknown type"); abort(); } @@ -621,7 +620,7 @@ public: if (op[3] == '_') return makeBinary(s, op[4] == 'u' ? BINARY_INT(DivU) : BINARY_INT(DivS), type); if (op[3] == 0) return makeBinary(s, BINARY_FLOAT(Div), type); } - if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type); + if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type); abort_on(op); } case 'e': { @@ -739,6 +738,10 @@ public: } else if (str[1] == 'u') return makeHost(s, HostOp::CurrentMemory); abort_on(str); } + case 'd': { + if (str[1] == 'r') return makeDrop(s); + abort_on(str); + } case 'e': { if (str[1] == 'l') return makeThenOrElse(s); abort_on(str); @@ -785,6 +788,7 @@ public: } case 't': { if (str[1] == 'h') return makeThenOrElse(s); + if (str[1] == 'e' && str[2] == 'e') return makeTeeLocal(s); abort_on(str); } case 'u': { @@ -878,6 +882,13 @@ private: return ret; } + Expression* makeDrop(Element& s) { + auto ret = allocator.alloc<Drop>(); + ret->value = parseExpression(s[1]); + ret->finalize(); + return ret; + } + Expression* makeHost(Element& s, HostOp op) { auto ret = allocator.alloc<Host>(); ret->op = op; @@ -910,11 +921,18 @@ private: return ret; } + Expression* makeTeeLocal(Element& s) { + auto ret = allocator.alloc<SetLocal>(); + ret->index = getLocalIndex(*s[1]); + ret->value = parseExpression(s[2]); + ret->setTee(true); + return ret; + } Expression* makeSetLocal(Element& s) { auto ret = allocator.alloc<SetLocal>(); ret->index = getLocalIndex(*s[1]); ret->value = parseExpression(s[2]); - ret->type = currFunction->getLocalType(ret->index); + ret->setTee(false); return ret; } @@ -1061,7 +1079,7 @@ private: Expression* makeStore(Element& s, WasmType type) { const char *extra = strchr(s[0]->c_str(), '.') + 6; // after "type.store" auto ret = allocator.alloc<Store>(); - ret->type = type; + ret->valueType = type; ret->bytes = getWasmTypeSize(type); if (extra[0] == '8') { ret->bytes = 1; @@ -1092,6 +1110,7 @@ private: } ret->ptr = parseExpression(s[i]); ret->value = parseExpression(s[i+1]); + ret->finalize(); return ret; } diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index b50ca0fb2..d6484abdb 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -53,6 +53,7 @@ struct Visitor { ReturnType visitUnary(Unary *curr) {} ReturnType visitBinary(Binary *curr) {} ReturnType visitSelect(Select *curr) {} + ReturnType visitDrop(Drop *curr) {} ReturnType visitReturn(Return *curr) {} ReturnType visitHost(Host *curr) {} ReturnType visitNop(Nop *curr) {} @@ -93,6 +94,7 @@ struct Visitor { case Expression::Id::UnaryId: DELEGATE(Unary); case Expression::Id::BinaryId: DELEGATE(Binary); case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::DropId: DELEGATE(Drop); case Expression::Id::ReturnId: DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); case Expression::Id::NopId: DELEGATE(Nop); @@ -132,6 +134,7 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> { ReturnType visitUnary(Unary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitBinary(Binary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitSelect(Select *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitDrop(Drop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitReturn(Return *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitHost(Host *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitNop(Nop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } @@ -264,14 +267,15 @@ struct Walker : public VisitorType { static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitCallIndirect((*currp)->cast<CallIndirect>()); } static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitGetLocal((*currp)->cast<GetLocal>()); } static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitSetLocal((*currp)->cast<SetLocal>()); } - static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); } - static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); } + static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); } + static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); } static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); } static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); } static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); } static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); } static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); } static void doVisitSelect(SubType* self, Expression** currp) { self->visitSelect((*currp)->cast<Select>()); } + static void doVisitDrop(SubType* self, Expression** currp) { self->visitDrop((*currp)->cast<Drop>()); } static void doVisitReturn(SubType* self, Expression** currp) { self->visitReturn((*currp)->cast<Return>()); } static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast<Host>()); } static void doVisitNop(SubType* self, Expression** currp) { self->visitNop((*currp)->cast<Nop>()); } @@ -411,6 +415,11 @@ struct PostWalker : public Walker<SubType, VisitorType> { self->pushTask(SubType::scan, &curr->cast<Select>()->ifTrue); break; } + case Expression::Id::DropId: { + self->pushTask(SubType::doVisitDrop, currp); + self->pushTask(SubType::scan, &curr->cast<Drop>()->value); + break; + } case Expression::Id::ReturnId: { self->pushTask(SubType::doVisitReturn, currp); self->maybePushTask(SubType::scan, &curr->cast<Return>()->value); diff --git a/src/wasm-validator.h b/src/wasm-validator.h index 3d9a0e1c3..b4818e253 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -30,7 +30,16 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>> bool valid = true; bool validateWebConstraints = false; - std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type + struct BreakInfo { + WasmType type; + Index arity; + BreakInfo() {} + BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {} + }; + + std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them + std::map<Expression*, BreakInfo> breakInfos; + WasmType returnType = unreachable; // type used in returns public: @@ -42,55 +51,107 @@ public: // visitors + static void visitPreBlock(WasmValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Block>(); + if (curr->name.is()) self->breakTargets[curr->name].push_back(curr); + } + void visitBlock(Block *curr) { // if we are break'ed to, then the value must be right for us if (curr->name.is()) { - // none or unreachable means a poison value that we should ignore - if consumed, it will error - if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) { - shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value"); + if (breakInfos.count(curr) > 0) { + auto& info = breakInfos[curr]; + // none or unreachable means a poison value that we should ignore - if consumed, it will error + if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) { + shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value"); + } + shouldBeTrue(info.arity != Index(-1), curr, "break arities must match"); + if (curr->list.size() > 0) { + auto last = curr->list.back()->type; + if (isConcreteWasmType(last) && info.type != unreachable) { + shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value"); + } + if (last == none) { + shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type"); + } + } + } + breakTargets[curr->name].pop_back(); + } + if (curr->list.size() > 1) { + for (Index i = 0; i < curr->list.size() - 1; i++) { + if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) { + std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n"; + } } - breakTypes.erase(curr->name); } } - void visitIf(If *curr) { - shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + + static void visitPreLoop(WasmValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Loop>(); + if (curr->in.is()) self->breakTargets[curr->in].push_back(curr); + if (curr->out.is()) self->breakTargets[curr->out].push_back(curr); } + void visitLoop(Loop *curr) { if (curr->in.is()) { - breakTypes.erase(curr->in); + breakTargets[curr->in].pop_back(); } if (curr->out.is()) { - breakTypes.erase(curr->out); + breakTargets[curr->out].pop_back(); } } - void noteBreak(Name name, Expression* value) { + + void visitIf(If *curr) { + shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + } + + // override scan to add a pre and a post check task to all nodes + static void scan(WasmValidator* self, Expression** currp) { + PostWalker<WasmValidator, Visitor<WasmValidator>>::scan(self, currp); + + auto* curr = *currp; + if (curr->is<Block>()) self->pushTask(visitPreBlock, currp); + if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp); + } + + void noteBreak(Name name, Expression* value, Expression* curr) { WasmType valueType = none; + Index arity = 0; if (value) { valueType = value->type; + shouldBeUnequal(valueType, none, curr, "breaks must have a valid value"); + arity = 1; } - if (breakTypes.count(name) == 0) { - breakTypes[name] = valueType; + if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return; + auto* target = breakTargets[name].back(); + if (breakInfos.count(target) == 0) { + breakInfos[target] = BreakInfo(valueType, arity); } else { - if (breakTypes[name] == unreachable) { - breakTypes[name] = valueType; + auto& info = breakInfos[target]; + if (info.type == unreachable) { + info.type = valueType; } else if (valueType != unreachable) { - if (valueType != breakTypes[name]) { - breakTypes[name] = none; // a poison value that must not be consumed + if (valueType != info.type) { + info.type = none; // a poison value that must not be consumed } } + if (arity != info.arity) { + info.arity = Index(-1); // a poison value + } } } void visitBreak(Break *curr) { - noteBreak(curr->name, curr->value); + noteBreak(curr->name, curr->value, curr); if (curr->condition) { shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32"); } } void visitSwitch(Switch *curr) { for (auto& target : curr->targets) { - noteBreak(target, curr->value); + noteBreak(target, curr->value, curr); } - noteBreak(curr->default_, curr->value); + noteBreak(curr->default_, curr->value, curr); shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32"); } void visitCall(Call *curr) { @@ -128,7 +189,9 @@ public: void visitSetLocal(SetLocal *curr) { shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough"); if (curr->value->type != unreachable) { - shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); + if (curr->type != none) { // tee is ok anyhow + shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); + } shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function"); } } @@ -139,7 +202,8 @@ public: void visitStore(Store *curr) { validateAlignment(curr->align); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32"); - shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match"); + shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none"); + // TODO: enable a check that replaces this, for type being none shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match"); } void visitBinary(Binary *curr) { if (curr->left->type != unreachable && curr->right->type != unreachable) { @@ -268,13 +332,11 @@ public: void visitFunction(Function *curr) { // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return - if (curr->result != none) { - if (curr->body->type != unreachable) { - shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); - } - if (returnType != unreachable) { - shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns"); - } + if (curr->body->type != unreachable) { + shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); + } + if (returnType != unreachable) { + shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns"); } returnType = unreachable; } @@ -311,12 +373,6 @@ public: void doWalkFunction(Function* func) { PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func); - if (!shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid")) { - for (auto& target : breakTypes) { - std::cerr << " - " << target.first << '\n'; - } - breakTypes.clear(); - } } private: diff --git a/src/wasm.h b/src/wasm.h index 68558033d..ee3e12910 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -802,7 +802,7 @@ enum UnaryOp { ConvertSInt32ToFloat32, ConvertSInt32ToFloat64, ConvertUInt32ToFloat32, ConvertUInt32ToFloat64, ConvertSInt64ToFloat32, ConvertSInt64ToFloat64, ConvertUInt64ToFloat32, ConvertUInt64ToFloat64, // int to float PromoteFloat32, // f32 to f64 DemoteFloat64, // f64 to f32 - ReinterpretInt32, ReinterpretInt64 // reinterpret bits to float + ReinterpretInt32, ReinterpretInt64, // reinterpret bits to float }; enum BinaryOp { @@ -877,6 +877,7 @@ public: UnaryId, BinaryId, SelectId, + DropId, ReturnId, HostId, NopId, @@ -929,6 +930,7 @@ inline const char *getExpressionName(Expression *curr) { case Expression::Id::UnaryId: return "unary"; case Expression::Id::BinaryId: return "binary"; case Expression::Id::SelectId: return "select"; + case Expression::Id::DropId: return "drop"; case Expression::Id::ReturnId: return "return"; case Expression::Id::HostId: return "host"; case Expression::Id::NopId: return "nop"; @@ -1108,8 +1110,13 @@ public: Index index; Expression *value; - void finalize() { - type = value->type; + bool isTee() { + return type != none; + } + + void setTee(bool is) { + if (is) type = value->type; + else type = none; } }; @@ -1150,16 +1157,17 @@ public: class Store : public SpecificExpression<Expression::StoreId> { public: - Store() {} - Store(MixedArena& allocator) {} + Store() : valueType(none) {} + Store(MixedArena& allocator) : Store() {} uint8_t bytes; Address offset; Address align; Expression *ptr, *value; + WasmType valueType; // the store never returns a value void finalize() { - type = value->type; + assert(valueType != none); // must be set } }; @@ -1312,6 +1320,14 @@ public: } }; +class Drop : public SpecificExpression<Expression::DropId> { +public: + Drop() {} + Drop(MixedArena& allocator) {} + + Expression *value; +}; + class Return : public SpecificExpression<Expression::ReturnId> { public: Return() : value(nullptr) { |