diff options
Diffstat (limited to 'src/wasm-validator.h')
-rw-r--r-- | src/wasm-validator.h | 124 |
1 files changed, 90 insertions, 34 deletions
diff --git a/src/wasm-validator.h b/src/wasm-validator.h index 3d9a0e1c3..b4818e253 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -30,7 +30,16 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>> bool valid = true; bool validateWebConstraints = false; - std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type + struct BreakInfo { + WasmType type; + Index arity; + BreakInfo() {} + BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {} + }; + + std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them + std::map<Expression*, BreakInfo> breakInfos; + WasmType returnType = unreachable; // type used in returns public: @@ -42,55 +51,107 @@ public: // visitors + static void visitPreBlock(WasmValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Block>(); + if (curr->name.is()) self->breakTargets[curr->name].push_back(curr); + } + void visitBlock(Block *curr) { // if we are break'ed to, then the value must be right for us if (curr->name.is()) { - // none or unreachable means a poison value that we should ignore - if consumed, it will error - if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) { - shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value"); + if (breakInfos.count(curr) > 0) { + auto& info = breakInfos[curr]; + // none or unreachable means a poison value that we should ignore - if consumed, it will error + if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) { + shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value"); + } + shouldBeTrue(info.arity != Index(-1), curr, "break arities must match"); + if (curr->list.size() > 0) { + auto last = curr->list.back()->type; + if (isConcreteWasmType(last) && info.type != unreachable) { + shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value"); + } + if (last == none) { + shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type"); + } + } + } + breakTargets[curr->name].pop_back(); + } + if (curr->list.size() > 1) { + for (Index i = 0; i < curr->list.size() - 1; i++) { + if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) { + std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n"; + } } - breakTypes.erase(curr->name); } } - void visitIf(If *curr) { - shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + + static void visitPreLoop(WasmValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Loop>(); + if (curr->in.is()) self->breakTargets[curr->in].push_back(curr); + if (curr->out.is()) self->breakTargets[curr->out].push_back(curr); } + void visitLoop(Loop *curr) { if (curr->in.is()) { - breakTypes.erase(curr->in); + breakTargets[curr->in].pop_back(); } if (curr->out.is()) { - breakTypes.erase(curr->out); + breakTargets[curr->out].pop_back(); } } - void noteBreak(Name name, Expression* value) { + + void visitIf(If *curr) { + shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + } + + // override scan to add a pre and a post check task to all nodes + static void scan(WasmValidator* self, Expression** currp) { + PostWalker<WasmValidator, Visitor<WasmValidator>>::scan(self, currp); + + auto* curr = *currp; + if (curr->is<Block>()) self->pushTask(visitPreBlock, currp); + if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp); + } + + void noteBreak(Name name, Expression* value, Expression* curr) { WasmType valueType = none; + Index arity = 0; if (value) { valueType = value->type; + shouldBeUnequal(valueType, none, curr, "breaks must have a valid value"); + arity = 1; } - if (breakTypes.count(name) == 0) { - breakTypes[name] = valueType; + if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return; + auto* target = breakTargets[name].back(); + if (breakInfos.count(target) == 0) { + breakInfos[target] = BreakInfo(valueType, arity); } else { - if (breakTypes[name] == unreachable) { - breakTypes[name] = valueType; + auto& info = breakInfos[target]; + if (info.type == unreachable) { + info.type = valueType; } else if (valueType != unreachable) { - if (valueType != breakTypes[name]) { - breakTypes[name] = none; // a poison value that must not be consumed + if (valueType != info.type) { + info.type = none; // a poison value that must not be consumed } } + if (arity != info.arity) { + info.arity = Index(-1); // a poison value + } } } void visitBreak(Break *curr) { - noteBreak(curr->name, curr->value); + noteBreak(curr->name, curr->value, curr); if (curr->condition) { shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32"); } } void visitSwitch(Switch *curr) { for (auto& target : curr->targets) { - noteBreak(target, curr->value); + noteBreak(target, curr->value, curr); } - noteBreak(curr->default_, curr->value); + noteBreak(curr->default_, curr->value, curr); shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32"); } void visitCall(Call *curr) { @@ -128,7 +189,9 @@ public: void visitSetLocal(SetLocal *curr) { shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough"); if (curr->value->type != unreachable) { - shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); + if (curr->type != none) { // tee is ok anyhow + shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); + } shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function"); } } @@ -139,7 +202,8 @@ public: void visitStore(Store *curr) { validateAlignment(curr->align); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32"); - shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match"); + shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none"); + // TODO: enable a check that replaces this, for type being none shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match"); } void visitBinary(Binary *curr) { if (curr->left->type != unreachable && curr->right->type != unreachable) { @@ -268,13 +332,11 @@ public: void visitFunction(Function *curr) { // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return - if (curr->result != none) { - if (curr->body->type != unreachable) { - shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); - } - if (returnType != unreachable) { - shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns"); - } + if (curr->body->type != unreachable) { + shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); + } + if (returnType != unreachable) { + shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns"); } returnType = unreachable; } @@ -311,12 +373,6 @@ public: void doWalkFunction(Function* func) { PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func); - if (!shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid")) { - for (auto& target : breakTypes) { - std::cerr << " - " << target.first << '\n'; - } - breakTypes.clear(); - } } private: |