From 0b6a1f5aa243d8542c086d4274e43bf411d63d59 Mon Sep 17 00:00:00 2001 From: Alon Zakai Date: Wed, 30 May 2018 16:51:12 -0700 Subject: Optimize validation of many nested blocks (#1576) On the testcase from https://github.com/tweag/asterius/issues/19#issuecomment-393052653 this makes us almost 3x faster, and use 25% less memory. The main improvement here is to simplify and optimize the data structures the validator uses to validate br targets: use unordered maps, and use one less of them. Also some speedups from using that map more effectively (use of iterators to avoid multiple lookups). Also move the duplicate-node checks to the internal IR validation section, which makes more sense anyhow (it's not wasm validation, it's internal IR validation, which like the check for stale internal types, we do only if debugging). --- src/wasm/wasm-validator.cpp | 93 ++++++++++++++++++++++++--------------------- 1 file changed, 49 insertions(+), 44 deletions(-) (limited to 'src/wasm/wasm-validator.cpp') diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 7246a15a6..9b5384efe 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -176,20 +176,27 @@ struct FunctionValidator : public WalkerPass> { FunctionValidator(ValidationInfo* info) : info(*info) {} struct BreakInfo { + enum { + UnsetArity = Index(-1), + PoisonArity = Index(-2) + }; + Type type; Index arity; - BreakInfo() {} + BreakInfo() : arity(UnsetArity) {} BreakInfo(Type type, Index arity) : type(type), arity(arity) {} + + bool hasBeenSet() { + // Compare to the impossible value. + return arity != UnsetArity; + } }; - std::map breakTargets; - std::map breakInfos; + std::unordered_map breakInfos; Type returnType = unreachable; // type used in returns - std::set labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that - - std::unordered_set seenExpressions; // expressions must not appear twice + std::unordered_set labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that void noteLabelName(Name name); @@ -198,14 +205,14 @@ public: static void visitPreBlock(FunctionValidator* self, Expression** currp) { auto* curr = (*currp)->cast(); - if (curr->name.is()) self->breakTargets[curr->name] = curr; + if (curr->name.is()) self->breakInfos[curr->name]; } void visitBlock(Block* curr); static void visitPreLoop(FunctionValidator* self, Expression** currp) { auto* curr = (*currp)->cast(); - if (curr->name.is()) self->breakTargets[curr->name] = curr; + if (curr->name.is()) self->breakInfos[curr->name]; } void visitLoop(Loop* curr); @@ -285,16 +292,19 @@ private: void FunctionValidator::noteLabelName(Name name) { if (!name.is()) return; - shouldBeTrue(labelNames.find(name) == labelNames.end(), name, "names in Binaryen IR must be unique - IR generators must ensure that"); - labelNames.insert(name); + bool inserted; + std::tie(std::ignore, inserted) = labelNames.insert(name); + shouldBeTrue(inserted, name, "names in Binaryen IR must be unique - IR generators must ensure that"); } void FunctionValidator::visitBlock(Block* curr) { // if we are break'ed to, then the value must be right for us if (curr->name.is()) { noteLabelName(curr->name); - if (breakInfos.count(curr) > 0) { - auto& info = breakInfos[curr]; + auto iter = breakInfos.find(curr->name); + assert(iter != breakInfos.end()); // we set it ourselves + auto& info = iter->second; + if (info.hasBeenSet()) { if (isConcreteType(curr->type)) { shouldBeTrue(info.arity != 0, curr, "break arities must be > 0 if block has a value"); } else { @@ -307,7 +317,7 @@ void FunctionValidator::visitBlock(Block* curr) { if (isConcreteType(curr->type) && info.arity && info.type != unreachable) { shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks have arity"); } - shouldBeTrue(info.arity != Index(-1), curr, "break arities must match"); + shouldBeTrue(info.arity != BreakInfo::PoisonArity, curr, "break arities must match"); if (curr->list.size() > 0) { auto last = curr->list.back()->type; if (isConcreteType(last) && info.type != unreachable) { @@ -318,7 +328,7 @@ void FunctionValidator::visitBlock(Block* curr) { } } } - breakTargets.erase(curr->name); + breakInfos.erase(iter); } if (curr->list.size() > 1) { for (Index i = 0; i < curr->list.size() - 1; i++) { @@ -347,11 +357,13 @@ void FunctionValidator::visitBlock(Block* curr) { void FunctionValidator::visitLoop(Loop* curr) { if (curr->name.is()) { noteLabelName(curr->name); - breakTargets.erase(curr->name); - if (breakInfos.count(curr) > 0) { - auto& info = breakInfos[curr]; + auto iter = breakInfos.find(curr->name); + assert(iter != breakInfos.end()); // we set it ourselves + auto& info = iter->second; + if (info.hasBeenSet()) { shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value"); } + breakInfos.erase(iter); } if (curr->type == none) { shouldBeFalse(isConcreteType(curr->body->type), curr, "bad body for a loop that has no value"); @@ -394,12 +406,12 @@ void FunctionValidator::noteBreak(Name name, Expression* value, Expression* curr shouldBeUnequal(valueType, none, curr, "breaks must have a valid value"); arity = 1; } - if (!shouldBeTrue(breakTargets.count(name) > 0, curr, "all break targets must be valid")) return; - auto* target = breakTargets[name]; - if (breakInfos.count(target) == 0) { - breakInfos[target] = BreakInfo(valueType, arity); + auto iter = breakInfos.find(name); + if (!shouldBeTrue(iter != breakInfos.end(), curr, "all break targets must be valid")) return; + auto& info = iter->second; + if (!info.hasBeenSet()) { + info = BreakInfo(valueType, arity); } else { - auto& info = breakInfos[target]; if (info.type == unreachable) { info.type = valueType; } else if (valueType != unreachable) { @@ -408,7 +420,7 @@ void FunctionValidator::noteBreak(Name name, Expression* value, Expression* curr } } if (arity != info.arity) { - info.arity = Index(-1); // a poison value + info.arity = BreakInfo::PoisonArity; } } } @@ -810,7 +822,7 @@ void FunctionValidator::visitFunction(Function* curr) { if (returnType != unreachable) { shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function has returns"); } - shouldBeTrue(breakTargets.empty(), curr->body, "all named break targets must exist"); + shouldBeTrue(breakInfos.empty(), curr->body, "all named break targets must exist"); returnType = unreachable; labelNames.clear(); // if function has a named type, it must match up with the function's params and result @@ -819,24 +831,6 @@ void FunctionValidator::visitFunction(Function* curr) { shouldBeTrue(ft->params == curr->params, curr->name, "function params must match its declared type"); shouldBeTrue(ft->result == curr->result, curr->name, "function result must match its declared type"); } - // expressions must not be seen more than once - struct Walker : public PostWalker> { - std::unordered_set& seen; - std::vector dupes; - - Walker(std::unordered_set& seen) : seen(seen) {} - - void visitExpression(Expression* curr) { - bool inserted; - std::tie(std::ignore, inserted) = seen.insert(curr); - if (!inserted) dupes.push_back(curr); - } - }; - Walker walker(seenExpressions); - walker.walk(curr->body); - for (auto* bad : walker.dupes) { - info.fail("expression seen more than once in the tree", bad, getFunction()); - } } static bool checkOffset(Expression* curr, Address add, Address max) { @@ -890,9 +884,12 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { struct BinaryenIRValidator : public PostWalker> { ValidationInfo& info; + std::unordered_set seen; + BinaryenIRValidator(ValidationInfo& info) : info(info) {} void visitExpression(Expression* curr) { + auto scope = getFunction() ? getFunction()->name : Name("(global scope)"); // check if a node type is 'stale', i.e., we forgot to finalize() the node. auto oldType = curr->type; ReFinalizeNode().visit(curr); @@ -907,11 +904,19 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { // ok for it to be either i32 or unreachable. if (!(isConcreteType(oldType) && newType == unreachable)) { std::ostringstream ss; - ss << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printType(oldType) << ", should be " << printType(newType) << ")\n"; + ss << "stale type found in " << scope << " on " << curr << "\n(marked as " << printType(oldType) << ", should be " << printType(newType) << ")\n"; info.fail(ss.str(), curr, getFunction()); } curr->type = oldType; } + // check if a node is a duplicate - expressions must not be seen more than once + bool inserted; + std::tie(std::ignore, inserted) = seen.insert(curr); + if (!inserted) { + std::ostringstream ss; + ss << "expression seen more than once in the tree in " << scope << " on " << curr << '\n'; + info.fail(ss.str(), curr, getFunction()); + } } }; BinaryenIRValidator binaryenIRValidator(info); @@ -952,7 +957,7 @@ static void validateExports(Module& module, ValidationInfo& info) { } } } - std::set exportNames; + std::unordered_set exportNames; for (auto& exp : module.exports) { Name name = exp->value; if (exp->kind == ExternalKind::Function) { -- cgit v1.2.3