diff options
Diffstat (limited to 'src/wasm-validator.h')
-rw-r--r-- | src/wasm-validator.h | 111 |
1 files changed, 90 insertions, 21 deletions
diff --git a/src/wasm-validator.h b/src/wasm-validator.h index f93e06524..8e94fd368 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -30,6 +30,7 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>> bool valid; std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type + WasmType returnType = unreachable; // type used in returns public: bool validate(Module& module) { @@ -51,7 +52,7 @@ public: } } void visitIf(If *curr) { - shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "if condition must be i32"); + shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); } void visitLoop(Loop *curr) { if (curr->in.is()) { @@ -96,29 +97,39 @@ public: } void visitCall(Call *curr) { auto* target = getModule()->getFunction(curr->target); + shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match"); for (size_t i = 0; i < curr->operands.size(); i++) { - shouldBeTrue(curr->operands[i]->type == target->params[i], curr, "call param types must match"); + shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match"); } } void visitCallImport(CallImport *curr) { auto* target = getModule()->getImport(curr->target)->type; + shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match"); for (size_t i = 0; i < curr->operands.size(); i++) { - shouldBeTrue(curr->operands[i]->type == target->params[i], curr, "call param types must match"); + shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match"); + } + } + void visitCallIndirect(CallIndirect *curr) { + auto* type = curr->fullType; + shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32"); + shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match"); + for (size_t i = 0; i < curr->operands.size(); i++) { + shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match"); } } void visitSetLocal(SetLocal *curr) { if (curr->value->type != unreachable) { - shouldBeEqual(curr->type, curr->value->type, curr, "set_local type must be correct"); + shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); } } void visitLoad(Load *curr) { validateAlignment(curr->align); - shouldBeEqual(curr->ptr->type, i32, curr, "load pointer type must be i32"); + shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32"); } void visitStore(Store *curr) { validateAlignment(curr->align); - shouldBeEqual(curr->ptr->type, i32, curr, "store pointer type must be i32"); - shouldBeEqual(curr->value->type, curr->type, curr, "store value type must match"); + shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32"); + shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match"); } void visitBinary(Binary *curr) { if (curr->left->type != unreachable && curr->right->type != unreachable) { @@ -126,24 +137,45 @@ public: } } void visitUnary(Unary *curr) { + shouldBeUnequal(curr->value->type, none, curr, "unaries must not receive a none as their input"); + switch (curr->op) { + case EqZInt32: + case EqZInt64: { + shouldBeEqual(curr->type, i32, curr, "eqz must return i32"); + break; + } + default: {} + } + if (curr->value->type == unreachable) return; switch (curr->op) { - case Clz: - case Ctz: - case Popcnt: - case Neg: - case Abs: - case Ceil: - case Floor: - case Trunc: - case Nearest: - case Sqrt: { + case ClzInt32: + case CtzInt32: + case PopcntInt32: + case NegFloat32: + case AbsFloat32: + case CeilFloat32: + case FloorFloat32: + case TruncFloat32: + case NearestFloat32: + case SqrtFloat32: + case ClzInt64: + case CtzInt64: + case PopcntInt64: + case NegFloat64: + case AbsFloat64: + case CeilFloat64: + case FloorFloat64: + case TruncFloat64: + case NearestFloat64: + case SqrtFloat64: { if (curr->value->type != unreachable) { shouldBeEqual(curr->value->type, curr->type, curr, "non-conversion unaries must return the same type"); } break; } - case EqZ: { - shouldBeEqual(curr->type, i32, curr, "relational unaries must return i32"); + case EqZInt32: + case EqZInt64: { + shouldBeTrue(curr->value->type == i32 || curr->value->type == i64, curr, "eqz input must be i32 or i64"); break; } case ExtendSInt32: shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break; @@ -175,15 +207,42 @@ public: } } + void visitReturn(Return* curr) { + if (curr->value) { + returnType = curr->value->type; + } + } + + void visitHost(Host* curr) { + switch (curr->op) { + case GrowMemory: { + shouldBeEqual(curr->operands.size(), size_t(1), curr, "grow_memory must have 1 operand"); + shouldBeEqualOrFirstIsUnreachable(curr->operands[0]->type, i32, curr, "grow_memory must have i32 operand"); + break; + } + case PageSize: + case CurrentMemory: + case HasFeature: break; + default: WASM_UNREACHABLE(); + } + } + void visitFunction(Function *curr) { // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return - if (curr->result != none && curr->body->type != unreachable) { - shouldBeEqual(curr->result, curr->body->type, curr->body, "function result must match, if function returns"); + if (curr->result != none) { + if (curr->body->type != unreachable) { + shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); + } + if (returnType != unreachable) { + shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns"); + } } + returnType = unreachable; } void visitMemory(Memory *curr) { shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial"); + shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "total memory must be <= 4GB"); size_t top = 0; for (auto& segment : curr->segments) { shouldBeFalse(segment.offset < top, "memory", "segment offset is small enough"); @@ -296,6 +355,16 @@ private: } template<typename T, typename S> + bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) { + if (left != unreachable && left != right) { + fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl; + valid = false; + return false; + } + return true; + } + + template<typename T, typename S> bool shouldBeUnequal(S left, S right, T curr, const char* text) { if (left == right) { fail() << "" << left << " == " << right << ": " << text << ", on \n" << curr << std::endl; |