summaryrefslogtreecommitdiff
path: root/src/wasm-validator.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm-validator.h')
-rw-r--r--src/wasm-validator.h208
1 files changed, 158 insertions, 50 deletions
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 3d9a0e1c3..58a30f9a3 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -30,7 +30,16 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>>
bool valid = true;
bool validateWebConstraints = false;
- std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type
+ struct BreakInfo {
+ WasmType type;
+ Index arity;
+ BreakInfo() {}
+ BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {}
+ };
+
+ std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them
+ std::map<Expression*, BreakInfo> breakInfos;
+
WasmType returnType = unreachable; // type used in returns
public:
@@ -42,55 +51,107 @@ public:
// visitors
+ static void visitPreBlock(WasmValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Block>();
+ if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
+ }
+
void visitBlock(Block *curr) {
// if we are break'ed to, then the value must be right for us
if (curr->name.is()) {
- // none or unreachable means a poison value that we should ignore - if consumed, it will error
- if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) {
- shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value");
+ if (breakInfos.count(curr) > 0) {
+ auto& info = breakInfos[curr];
+ // none or unreachable means a poison value that we should ignore - if consumed, it will error
+ if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) {
+ shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value");
+ }
+ shouldBeTrue(info.arity != Index(-1), curr, "break arities must match");
+ if (curr->list.size() > 0) {
+ auto last = curr->list.back()->type;
+ if (isConcreteWasmType(last) && info.type != unreachable) {
+ shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value");
+ }
+ if (last == none) {
+ shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type");
+ }
+ }
+ }
+ breakTargets[curr->name].pop_back();
+ }
+ if (curr->list.size() > 1) {
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) {
+ std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n";
+ }
}
- breakTypes.erase(curr->name);
}
}
- void visitIf(If *curr) {
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+
+ static void visitPreLoop(WasmValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Loop>();
+ if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
}
+
void visitLoop(Loop *curr) {
- if (curr->in.is()) {
- breakTypes.erase(curr->in);
- }
- if (curr->out.is()) {
- breakTypes.erase(curr->out);
+ if (curr->name.is()) {
+ breakTargets[curr->name].pop_back();
+ if (breakInfos.count(curr) > 0) {
+ auto& info = breakInfos[curr];
+ shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value");
+ }
}
}
- void noteBreak(Name name, Expression* value) {
+
+ void visitIf(If *curr) {
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+ }
+
+ // override scan to add a pre and a post check task to all nodes
+ static void scan(WasmValidator* self, Expression** currp) {
+ PostWalker<WasmValidator, Visitor<WasmValidator>>::scan(self, currp);
+
+ auto* curr = *currp;
+ if (curr->is<Block>()) self->pushTask(visitPreBlock, currp);
+ if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp);
+ }
+
+ void noteBreak(Name name, Expression* value, Expression* curr) {
WasmType valueType = none;
+ Index arity = 0;
if (value) {
valueType = value->type;
+ shouldBeUnequal(valueType, none, curr, "breaks must have a valid value");
+ arity = 1;
}
- if (breakTypes.count(name) == 0) {
- breakTypes[name] = valueType;
+ if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return;
+ auto* target = breakTargets[name].back();
+ if (breakInfos.count(target) == 0) {
+ breakInfos[target] = BreakInfo(valueType, arity);
} else {
- if (breakTypes[name] == unreachable) {
- breakTypes[name] = valueType;
+ auto& info = breakInfos[target];
+ if (info.type == unreachable) {
+ info.type = valueType;
} else if (valueType != unreachable) {
- if (valueType != breakTypes[name]) {
- breakTypes[name] = none; // a poison value that must not be consumed
+ if (valueType != info.type) {
+ info.type = none; // a poison value that must not be consumed
}
}
+ if (arity != info.arity) {
+ info.arity = Index(-1); // a poison value
+ }
}
}
void visitBreak(Break *curr) {
- noteBreak(curr->name, curr->value);
+ noteBreak(curr->name, curr->value, curr);
if (curr->condition) {
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32");
}
}
void visitSwitch(Switch *curr) {
for (auto& target : curr->targets) {
- noteBreak(target, curr->value);
+ noteBreak(target, curr->value, curr);
}
- noteBreak(curr->default_, curr->value);
+ noteBreak(curr->default_, curr->value, curr);
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
}
void visitCall(Call *curr) {
@@ -106,7 +167,7 @@ public:
void visitCallImport(CallImport *curr) {
auto* import = getModule()->checkImport(curr->target);
if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return;
- auto* type = import->type;
+ auto* type = import->functionType;
if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match")) {
@@ -115,6 +176,7 @@ public:
}
}
void visitCallIndirect(CallIndirect *curr) {
+ shouldBeTrue(getModule()->table.segments.size() > 0, curr, "no table");
auto* type = getModule()->checkFunctionType(curr->fullType);
if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return;
shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32");
@@ -128,18 +190,21 @@ public:
void visitSetLocal(SetLocal *curr) {
shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough");
if (curr->value->type != unreachable) {
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ if (curr->type != none) { // tee is ok anyhow
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ }
shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function");
}
}
void visitLoad(Load *curr) {
- validateAlignment(curr->align);
+ validateAlignment(curr->align, curr->type, curr->bytes);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
}
void visitStore(Store *curr) {
- validateAlignment(curr->align);
+ validateAlignment(curr->align, curr->type, curr->bytes);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match");
+ shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match");
}
void visitBinary(Binary *curr) {
if (curr->left->type != unreachable && curr->right->type != unreachable) {
@@ -216,6 +281,10 @@ public:
default: abort();
}
}
+ void visitSelect(Select* curr) {
+ shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid");
+ shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid");
+ }
void visitReturn(Return* curr) {
if (curr->value) {
@@ -245,9 +314,11 @@ public:
void visitImport(Import* curr) {
if (!validateWebConstraints) return;
- shouldBeUnequal(curr->type->result, i64, curr->name, "Imported function must not have i64 return type");
- for (WasmType param : curr->type->params) {
- shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters");
+ if (curr->kind == Import::Function) {
+ shouldBeUnequal(curr->functionType->result, i64, curr->name, "Imported function must not have i64 return type");
+ for (WasmType param : curr->functionType->params) {
+ shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters");
+ }
}
}
@@ -262,39 +333,66 @@ public:
void visitGlobal(Global* curr) {
shouldBeTrue(curr->init->is<Const>(), curr->name, "global init must be valid");
- shouldBeEqual(curr->type, curr->init->type, curr, "global init must have correct type");
+ shouldBeEqual(curr->type, curr->init->type, nullptr, "global init must have correct type");
}
void visitFunction(Function *curr) {
// if function has no result, it is ignored
// if body is unreachable, it might be e.g. a return
- if (curr->result != none) {
- if (curr->body->type != unreachable) {
- shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
- }
+ if (curr->body->type != unreachable) {
+ shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
+ }
+ if (curr->result != none) { // TODO: over previous too?
if (returnType != unreachable) {
shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns");
}
}
returnType = unreachable;
}
+
+ bool isConstant(Expression* curr) {
+ return curr->is<Const>() || curr->is<GetGlobal>();
+ }
+
void visitMemory(Memory *curr) {
shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB");
+ Index mustBeGreaterOrEqual = 0;
+ for (auto& segment : curr->segments) {
+ if (!shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue;
+ shouldBeTrue(isConstant(segment.offset), segment.offset, "segment offset should be constant");
+ Index size = segment.data.size();
+ shouldBeTrue(size <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
+ if (segment.offset->is<Const>()) {
+ Index start = segment.offset->cast<Const>()->value.geti32();
+ Index end = start + size;
+ shouldBeTrue(end <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
+ shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory");
+ mustBeGreaterOrEqual = end;
+ }
+ }
+ }
+ void visitTable(Table* curr) {
+ for (auto& segment : curr->segments) {
+ shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32");
+ shouldBeTrue(isConstant(segment.offset), segment.offset, "segment offset should be constant");
+ }
}
void visitModule(Module *curr) {
// exports
std::set<Name> exportNames;
for (auto& exp : curr->exports) {
Name name = exp->value;
- bool found = false;
- for (auto& func : curr->functions) {
- if (func->name == name) {
- found = true;
- break;
+ if (exp->kind == Export::Function) {
+ bool found = false;
+ for (auto& func : curr->functions) {
+ if (func->name == name) {
+ found = true;
+ break;
+ }
}
+ shouldBeTrue(found, name, "module exports must be found");
}
- shouldBeTrue(found, name, "module exports must be found");
Name exportName = exp->name;
shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique");
exportNames.insert(exportName);
@@ -311,12 +409,6 @@ public:
void doWalkFunction(Function* func) {
PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func);
- if (!shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid")) {
- for (auto& target : breakTypes) {
- std::cerr << " - " << target.first << '\n';
- }
- breakTypes.clear();
- }
}
private:
@@ -360,7 +452,8 @@ private:
template<typename T, typename S>
bool shouldBeEqual(S left, S right, T curr, const char* text) {
if (left != right) {
- fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl;
+ fail() << "" << left << " != " << right << ": " << text << ", on \n";
+ WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
valid = false;
return false;
}
@@ -379,7 +472,8 @@ private:
template<typename T, typename S>
bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) {
if (left != unreachable && left != right) {
- fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl;
+ fail() << "" << left << " != " << right << ": " << text << ", on \n";
+ WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
valid = false;
return false;
}
@@ -396,7 +490,7 @@ private:
return true;
}
- void validateAlignment(size_t align) {
+ void validateAlignment(size_t align, WasmType type, Index bytes) {
switch (align) {
case 1:
case 2:
@@ -408,6 +502,20 @@ private:
break;
}
}
+ shouldBeTrue(align <= bytes, align, "alignment must not exceed natural");
+ switch (type) {
+ case i32:
+ case f32: {
+ shouldBeTrue(align <= 4, align, "alignment must not exceed natural");
+ break;
+ }
+ case i64:
+ case f64: {
+ shouldBeTrue(align <= 8, align, "alignment must not exceed natural");
+ break;
+ }
+ default: {}
+ }
}
};