summaryrefslogtreecommitdiff
path: root/src/wasm-validator.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm-validator.h')
-rw-r--r--src/wasm-validator.h111
1 files changed, 90 insertions, 21 deletions
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index f93e06524..8e94fd368 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -30,6 +30,7 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>>
bool valid;
std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type
+ WasmType returnType = unreachable; // type used in returns
public:
bool validate(Module& module) {
@@ -51,7 +52,7 @@ public:
}
}
void visitIf(If *curr) {
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "if condition must be i32");
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
}
void visitLoop(Loop *curr) {
if (curr->in.is()) {
@@ -96,29 +97,39 @@ public:
}
void visitCall(Call *curr) {
auto* target = getModule()->getFunction(curr->target);
+ shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match");
for (size_t i = 0; i < curr->operands.size(); i++) {
- shouldBeTrue(curr->operands[i]->type == target->params[i], curr, "call param types must match");
+ shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match");
}
}
void visitCallImport(CallImport *curr) {
auto* target = getModule()->getImport(curr->target)->type;
+ shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match");
for (size_t i = 0; i < curr->operands.size(); i++) {
- shouldBeTrue(curr->operands[i]->type == target->params[i], curr, "call param types must match");
+ shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match");
+ }
+ }
+ void visitCallIndirect(CallIndirect *curr) {
+ auto* type = curr->fullType;
+ shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32");
+ shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match");
+ for (size_t i = 0; i < curr->operands.size(); i++) {
+ shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match");
}
}
void visitSetLocal(SetLocal *curr) {
if (curr->value->type != unreachable) {
- shouldBeEqual(curr->type, curr->value->type, curr, "set_local type must be correct");
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
}
}
void visitLoad(Load *curr) {
validateAlignment(curr->align);
- shouldBeEqual(curr->ptr->type, i32, curr, "load pointer type must be i32");
+ shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
}
void visitStore(Store *curr) {
validateAlignment(curr->align);
- shouldBeEqual(curr->ptr->type, i32, curr, "store pointer type must be i32");
- shouldBeEqual(curr->value->type, curr->type, curr, "store value type must match");
+ shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match");
}
void visitBinary(Binary *curr) {
if (curr->left->type != unreachable && curr->right->type != unreachable) {
@@ -126,24 +137,45 @@ public:
}
}
void visitUnary(Unary *curr) {
+ shouldBeUnequal(curr->value->type, none, curr, "unaries must not receive a none as their input");
+ switch (curr->op) {
+ case EqZInt32:
+ case EqZInt64: {
+ shouldBeEqual(curr->type, i32, curr, "eqz must return i32");
+ break;
+ }
+ default: {}
+ }
+ if (curr->value->type == unreachable) return;
switch (curr->op) {
- case Clz:
- case Ctz:
- case Popcnt:
- case Neg:
- case Abs:
- case Ceil:
- case Floor:
- case Trunc:
- case Nearest:
- case Sqrt: {
+ case ClzInt32:
+ case CtzInt32:
+ case PopcntInt32:
+ case NegFloat32:
+ case AbsFloat32:
+ case CeilFloat32:
+ case FloorFloat32:
+ case TruncFloat32:
+ case NearestFloat32:
+ case SqrtFloat32:
+ case ClzInt64:
+ case CtzInt64:
+ case PopcntInt64:
+ case NegFloat64:
+ case AbsFloat64:
+ case CeilFloat64:
+ case FloorFloat64:
+ case TruncFloat64:
+ case NearestFloat64:
+ case SqrtFloat64: {
if (curr->value->type != unreachable) {
shouldBeEqual(curr->value->type, curr->type, curr, "non-conversion unaries must return the same type");
}
break;
}
- case EqZ: {
- shouldBeEqual(curr->type, i32, curr, "relational unaries must return i32");
+ case EqZInt32:
+ case EqZInt64: {
+ shouldBeTrue(curr->value->type == i32 || curr->value->type == i64, curr, "eqz input must be i32 or i64");
break;
}
case ExtendSInt32: shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break;
@@ -175,15 +207,42 @@ public:
}
}
+ void visitReturn(Return* curr) {
+ if (curr->value) {
+ returnType = curr->value->type;
+ }
+ }
+
+ void visitHost(Host* curr) {
+ switch (curr->op) {
+ case GrowMemory: {
+ shouldBeEqual(curr->operands.size(), size_t(1), curr, "grow_memory must have 1 operand");
+ shouldBeEqualOrFirstIsUnreachable(curr->operands[0]->type, i32, curr, "grow_memory must have i32 operand");
+ break;
+ }
+ case PageSize:
+ case CurrentMemory:
+ case HasFeature: break;
+ default: WASM_UNREACHABLE();
+ }
+ }
+
void visitFunction(Function *curr) {
// if function has no result, it is ignored
// if body is unreachable, it might be e.g. a return
- if (curr->result != none && curr->body->type != unreachable) {
- shouldBeEqual(curr->result, curr->body->type, curr->body, "function result must match, if function returns");
+ if (curr->result != none) {
+ if (curr->body->type != unreachable) {
+ shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
+ }
+ if (returnType != unreachable) {
+ shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns");
+ }
}
+ returnType = unreachable;
}
void visitMemory(Memory *curr) {
shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
+ shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "total memory must be <= 4GB");
size_t top = 0;
for (auto& segment : curr->segments) {
shouldBeFalse(segment.offset < top, "memory", "segment offset is small enough");
@@ -296,6 +355,16 @@ private:
}
template<typename T, typename S>
+ bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) {
+ if (left != unreachable && left != right) {
+ fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl;
+ valid = false;
+ return false;
+ }
+ return true;
+ }
+
+ template<typename T, typename S>
bool shouldBeUnequal(S left, S right, T curr, const char* text) {
if (left == right) {
fail() << "" << left << " == " << right << ": " << text << ", on \n" << curr << std::endl;