summaryrefslogtreecommitdiff
path: root/src/wasm-validator.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm-validator.h')
-rw-r--r--src/wasm-validator.h124
1 files changed, 90 insertions, 34 deletions
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 3d9a0e1c3..b4818e253 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -30,7 +30,16 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>>
bool valid = true;
bool validateWebConstraints = false;
- std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type
+ struct BreakInfo {
+ WasmType type;
+ Index arity;
+ BreakInfo() {}
+ BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {}
+ };
+
+ std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them
+ std::map<Expression*, BreakInfo> breakInfos;
+
WasmType returnType = unreachable; // type used in returns
public:
@@ -42,55 +51,107 @@ public:
// visitors
+ static void visitPreBlock(WasmValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Block>();
+ if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
+ }
+
void visitBlock(Block *curr) {
// if we are break'ed to, then the value must be right for us
if (curr->name.is()) {
- // none or unreachable means a poison value that we should ignore - if consumed, it will error
- if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) {
- shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value");
+ if (breakInfos.count(curr) > 0) {
+ auto& info = breakInfos[curr];
+ // none or unreachable means a poison value that we should ignore - if consumed, it will error
+ if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) {
+ shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value");
+ }
+ shouldBeTrue(info.arity != Index(-1), curr, "break arities must match");
+ if (curr->list.size() > 0) {
+ auto last = curr->list.back()->type;
+ if (isConcreteWasmType(last) && info.type != unreachable) {
+ shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value");
+ }
+ if (last == none) {
+ shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type");
+ }
+ }
+ }
+ breakTargets[curr->name].pop_back();
+ }
+ if (curr->list.size() > 1) {
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) {
+ std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n";
+ }
}
- breakTypes.erase(curr->name);
}
}
- void visitIf(If *curr) {
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+
+ static void visitPreLoop(WasmValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Loop>();
+ if (curr->in.is()) self->breakTargets[curr->in].push_back(curr);
+ if (curr->out.is()) self->breakTargets[curr->out].push_back(curr);
}
+
void visitLoop(Loop *curr) {
if (curr->in.is()) {
- breakTypes.erase(curr->in);
+ breakTargets[curr->in].pop_back();
}
if (curr->out.is()) {
- breakTypes.erase(curr->out);
+ breakTargets[curr->out].pop_back();
}
}
- void noteBreak(Name name, Expression* value) {
+
+ void visitIf(If *curr) {
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+ }
+
+ // override scan to add a pre and a post check task to all nodes
+ static void scan(WasmValidator* self, Expression** currp) {
+ PostWalker<WasmValidator, Visitor<WasmValidator>>::scan(self, currp);
+
+ auto* curr = *currp;
+ if (curr->is<Block>()) self->pushTask(visitPreBlock, currp);
+ if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp);
+ }
+
+ void noteBreak(Name name, Expression* value, Expression* curr) {
WasmType valueType = none;
+ Index arity = 0;
if (value) {
valueType = value->type;
+ shouldBeUnequal(valueType, none, curr, "breaks must have a valid value");
+ arity = 1;
}
- if (breakTypes.count(name) == 0) {
- breakTypes[name] = valueType;
+ if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return;
+ auto* target = breakTargets[name].back();
+ if (breakInfos.count(target) == 0) {
+ breakInfos[target] = BreakInfo(valueType, arity);
} else {
- if (breakTypes[name] == unreachable) {
- breakTypes[name] = valueType;
+ auto& info = breakInfos[target];
+ if (info.type == unreachable) {
+ info.type = valueType;
} else if (valueType != unreachable) {
- if (valueType != breakTypes[name]) {
- breakTypes[name] = none; // a poison value that must not be consumed
+ if (valueType != info.type) {
+ info.type = none; // a poison value that must not be consumed
}
}
+ if (arity != info.arity) {
+ info.arity = Index(-1); // a poison value
+ }
}
}
void visitBreak(Break *curr) {
- noteBreak(curr->name, curr->value);
+ noteBreak(curr->name, curr->value, curr);
if (curr->condition) {
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32");
}
}
void visitSwitch(Switch *curr) {
for (auto& target : curr->targets) {
- noteBreak(target, curr->value);
+ noteBreak(target, curr->value, curr);
}
- noteBreak(curr->default_, curr->value);
+ noteBreak(curr->default_, curr->value, curr);
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
}
void visitCall(Call *curr) {
@@ -128,7 +189,9 @@ public:
void visitSetLocal(SetLocal *curr) {
shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough");
if (curr->value->type != unreachable) {
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ if (curr->type != none) { // tee is ok anyhow
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ }
shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function");
}
}
@@ -139,7 +202,8 @@ public:
void visitStore(Store *curr) {
validateAlignment(curr->align);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match");
+ shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
+ // TODO: enable a check that replaces this, for type being none shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match");
}
void visitBinary(Binary *curr) {
if (curr->left->type != unreachable && curr->right->type != unreachable) {
@@ -268,13 +332,11 @@ public:
void visitFunction(Function *curr) {
// if function has no result, it is ignored
// if body is unreachable, it might be e.g. a return
- if (curr->result != none) {
- if (curr->body->type != unreachable) {
- shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
- }
- if (returnType != unreachable) {
- shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns");
- }
+ if (curr->body->type != unreachable) {
+ shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
+ }
+ if (returnType != unreachable) {
+ shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns");
}
returnType = unreachable;
}
@@ -311,12 +373,6 @@ public:
void doWalkFunction(Function* func) {
PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func);
- if (!shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid")) {
- for (auto& target : breakTypes) {
- std::cerr << " - " << target.first << '\n';
- }
- breakTypes.clear();
- }
}
private: