diff options
Diffstat (limited to 'src/wasm-validator.h')
-rw-r--r-- | src/wasm-validator.h | 208 |
1 files changed, 158 insertions, 50 deletions
diff --git a/src/wasm-validator.h b/src/wasm-validator.h index 3d9a0e1c3..58a30f9a3 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -30,7 +30,16 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>> bool valid = true; bool validateWebConstraints = false; - std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type + struct BreakInfo { + WasmType type; + Index arity; + BreakInfo() {} + BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {} + }; + + std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them + std::map<Expression*, BreakInfo> breakInfos; + WasmType returnType = unreachable; // type used in returns public: @@ -42,55 +51,107 @@ public: // visitors + static void visitPreBlock(WasmValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Block>(); + if (curr->name.is()) self->breakTargets[curr->name].push_back(curr); + } + void visitBlock(Block *curr) { // if we are break'ed to, then the value must be right for us if (curr->name.is()) { - // none or unreachable means a poison value that we should ignore - if consumed, it will error - if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) { - shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value"); + if (breakInfos.count(curr) > 0) { + auto& info = breakInfos[curr]; + // none or unreachable means a poison value that we should ignore - if consumed, it will error + if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) { + shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value"); + } + shouldBeTrue(info.arity != Index(-1), curr, "break arities must match"); + if (curr->list.size() > 0) { + auto last = curr->list.back()->type; + if (isConcreteWasmType(last) && info.type != unreachable) { + shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value"); + } + if (last == none) { + shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type"); + } + } + } + breakTargets[curr->name].pop_back(); + } + if (curr->list.size() > 1) { + for (Index i = 0; i < curr->list.size() - 1; i++) { + if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) { + std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n"; + } } - breakTypes.erase(curr->name); } } - void visitIf(If *curr) { - shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + + static void visitPreLoop(WasmValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Loop>(); + if (curr->name.is()) self->breakTargets[curr->name].push_back(curr); } + void visitLoop(Loop *curr) { - if (curr->in.is()) { - breakTypes.erase(curr->in); - } - if (curr->out.is()) { - breakTypes.erase(curr->out); + if (curr->name.is()) { + breakTargets[curr->name].pop_back(); + if (breakInfos.count(curr) > 0) { + auto& info = breakInfos[curr]; + shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value"); + } } } - void noteBreak(Name name, Expression* value) { + + void visitIf(If *curr) { + shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + } + + // override scan to add a pre and a post check task to all nodes + static void scan(WasmValidator* self, Expression** currp) { + PostWalker<WasmValidator, Visitor<WasmValidator>>::scan(self, currp); + + auto* curr = *currp; + if (curr->is<Block>()) self->pushTask(visitPreBlock, currp); + if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp); + } + + void noteBreak(Name name, Expression* value, Expression* curr) { WasmType valueType = none; + Index arity = 0; if (value) { valueType = value->type; + shouldBeUnequal(valueType, none, curr, "breaks must have a valid value"); + arity = 1; } - if (breakTypes.count(name) == 0) { - breakTypes[name] = valueType; + if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return; + auto* target = breakTargets[name].back(); + if (breakInfos.count(target) == 0) { + breakInfos[target] = BreakInfo(valueType, arity); } else { - if (breakTypes[name] == unreachable) { - breakTypes[name] = valueType; + auto& info = breakInfos[target]; + if (info.type == unreachable) { + info.type = valueType; } else if (valueType != unreachable) { - if (valueType != breakTypes[name]) { - breakTypes[name] = none; // a poison value that must not be consumed + if (valueType != info.type) { + info.type = none; // a poison value that must not be consumed } } + if (arity != info.arity) { + info.arity = Index(-1); // a poison value + } } } void visitBreak(Break *curr) { - noteBreak(curr->name, curr->value); + noteBreak(curr->name, curr->value, curr); if (curr->condition) { shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32"); } } void visitSwitch(Switch *curr) { for (auto& target : curr->targets) { - noteBreak(target, curr->value); + noteBreak(target, curr->value, curr); } - noteBreak(curr->default_, curr->value); + noteBreak(curr->default_, curr->value, curr); shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32"); } void visitCall(Call *curr) { @@ -106,7 +167,7 @@ public: void visitCallImport(CallImport *curr) { auto* import = getModule()->checkImport(curr->target); if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return; - auto* type = import->type; + auto* type = import->functionType; if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; for (size_t i = 0; i < curr->operands.size(); i++) { if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match")) { @@ -115,6 +176,7 @@ public: } } void visitCallIndirect(CallIndirect *curr) { + shouldBeTrue(getModule()->table.segments.size() > 0, curr, "no table"); auto* type = getModule()->checkFunctionType(curr->fullType); if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return; shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32"); @@ -128,18 +190,21 @@ public: void visitSetLocal(SetLocal *curr) { shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough"); if (curr->value->type != unreachable) { - shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); + if (curr->type != none) { // tee is ok anyhow + shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); + } shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function"); } } void visitLoad(Load *curr) { - validateAlignment(curr->align); + validateAlignment(curr->align, curr->type, curr->bytes); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32"); } void visitStore(Store *curr) { - validateAlignment(curr->align); + validateAlignment(curr->align, curr->type, curr->bytes); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32"); - shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match"); + shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none"); + shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match"); } void visitBinary(Binary *curr) { if (curr->left->type != unreachable && curr->right->type != unreachable) { @@ -216,6 +281,10 @@ public: default: abort(); } } + void visitSelect(Select* curr) { + shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid"); + shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid"); + } void visitReturn(Return* curr) { if (curr->value) { @@ -245,9 +314,11 @@ public: void visitImport(Import* curr) { if (!validateWebConstraints) return; - shouldBeUnequal(curr->type->result, i64, curr->name, "Imported function must not have i64 return type"); - for (WasmType param : curr->type->params) { - shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); + if (curr->kind == Import::Function) { + shouldBeUnequal(curr->functionType->result, i64, curr->name, "Imported function must not have i64 return type"); + for (WasmType param : curr->functionType->params) { + shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); + } } } @@ -262,39 +333,66 @@ public: void visitGlobal(Global* curr) { shouldBeTrue(curr->init->is<Const>(), curr->name, "global init must be valid"); - shouldBeEqual(curr->type, curr->init->type, curr, "global init must have correct type"); + shouldBeEqual(curr->type, curr->init->type, nullptr, "global init must have correct type"); } void visitFunction(Function *curr) { // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return - if (curr->result != none) { - if (curr->body->type != unreachable) { - shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); - } + if (curr->body->type != unreachable) { + shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); + } + if (curr->result != none) { // TODO: over previous too? if (returnType != unreachable) { shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns"); } } returnType = unreachable; } + + bool isConstant(Expression* curr) { + return curr->is<Const>() || curr->is<GetGlobal>(); + } + void visitMemory(Memory *curr) { shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial"); shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB"); + Index mustBeGreaterOrEqual = 0; + for (auto& segment : curr->segments) { + if (!shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue; + shouldBeTrue(isConstant(segment.offset), segment.offset, "segment offset should be constant"); + Index size = segment.data.size(); + shouldBeTrue(size <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); + if (segment.offset->is<Const>()) { + Index start = segment.offset->cast<Const>()->value.geti32(); + Index end = start + size; + shouldBeTrue(end <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); + shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory"); + mustBeGreaterOrEqual = end; + } + } + } + void visitTable(Table* curr) { + for (auto& segment : curr->segments) { + shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32"); + shouldBeTrue(isConstant(segment.offset), segment.offset, "segment offset should be constant"); + } } void visitModule(Module *curr) { // exports std::set<Name> exportNames; for (auto& exp : curr->exports) { Name name = exp->value; - bool found = false; - for (auto& func : curr->functions) { - if (func->name == name) { - found = true; - break; + if (exp->kind == Export::Function) { + bool found = false; + for (auto& func : curr->functions) { + if (func->name == name) { + found = true; + break; + } } + shouldBeTrue(found, name, "module exports must be found"); } - shouldBeTrue(found, name, "module exports must be found"); Name exportName = exp->name; shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique"); exportNames.insert(exportName); @@ -311,12 +409,6 @@ public: void doWalkFunction(Function* func) { PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func); - if (!shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid")) { - for (auto& target : breakTypes) { - std::cerr << " - " << target.first << '\n'; - } - breakTypes.clear(); - } } private: @@ -360,7 +452,8 @@ private: template<typename T, typename S> bool shouldBeEqual(S left, S right, T curr, const char* text) { if (left != right) { - fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl; + fail() << "" << left << " != " << right << ": " << text << ", on \n"; + WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl; valid = false; return false; } @@ -379,7 +472,8 @@ private: template<typename T, typename S> bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) { if (left != unreachable && left != right) { - fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl; + fail() << "" << left << " != " << right << ": " << text << ", on \n"; + WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl; valid = false; return false; } @@ -396,7 +490,7 @@ private: return true; } - void validateAlignment(size_t align) { + void validateAlignment(size_t align, WasmType type, Index bytes) { switch (align) { case 1: case 2: @@ -408,6 +502,20 @@ private: break; } } + shouldBeTrue(align <= bytes, align, "alignment must not exceed natural"); + switch (type) { + case i32: + case f32: { + shouldBeTrue(align <= 4, align, "alignment must not exceed natural"); + break; + } + case i64: + case f64: { + shouldBeTrue(align <= 8, align, "alignment must not exceed natural"); + break; + } + default: {} + } } }; |