diff options
author | Thomas Lively <7121787+tlively@users.noreply.github.com> | 2020-03-10 13:43:06 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-10 13:43:06 -0700 |
commit | 8f16059d3c29e285d4effed7f0c1f84c1f2f4d9d (patch) | |
tree | faaee424f0b2b77d199c385abe103ea6044cae4e /src | |
parent | 49e31f2034d9532f29704be3039829aa201556a0 (diff) | |
download | binaryen-8f16059d3c29e285d4effed7f0c1f84c1f2f4d9d.tar.gz binaryen-8f16059d3c29e285d4effed7f0c1f84c1f2f4d9d.tar.bz2 binaryen-8f16059d3c29e285d4effed7f0c1f84c1f2f4d9d.zip |
Handle multivalue returns in the interpreter (#2684)
Updates the interpreter to properly flow vectors of values, including
at function boundaries. Adds a small spec test for multivalue return.
Diffstat (limited to 'src')
-rw-r--r-- | src/literal.h | 8 | ||||
-rw-r--r-- | src/passes/Precompute.cpp | 31 | ||||
-rw-r--r-- | src/passes/SimplifyGlobals.cpp | 4 | ||||
-rw-r--r-- | src/shell-interface.h | 14 | ||||
-rw-r--r-- | src/tools/execution-results.h | 31 | ||||
-rw-r--r-- | src/tools/wasm-ctor-eval.cpp | 12 | ||||
-rw-r--r-- | src/tools/wasm-shell.cpp | 30 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 92 | ||||
-rw-r--r-- | src/wasm-type.h | 15 | ||||
-rw-r--r-- | src/wasm.h | 3 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 19 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 14 |
12 files changed, 169 insertions, 104 deletions
diff --git a/src/literal.h b/src/literal.h index bce578d46..0b87cd189 100644 --- a/src/literal.h +++ b/src/literal.h @@ -23,6 +23,7 @@ #include "compiler-support.h" #include "support/hash.h" #include "support/name.h" +#include "support/small_vector.h" #include "support/utilities.h" #include "wasm-type.h" @@ -179,8 +180,6 @@ public: static void printDouble(std::ostream& o, double d); static void printVec128(std::ostream& o, const std::array<uint8_t, 16>& v); - friend std::ostream& operator<<(std::ostream& o, Literal literal); - Literal countLeadingZeroes() const; Literal countTrailingZeroes() const; Literal popCount() const; @@ -445,6 +444,11 @@ private: Literal avgrUInt(const Literal& other) const; }; +using Literals = SmallVector<Literal, 1>; + +std::ostream& operator<<(std::ostream& o, wasm::Literal literal); +std::ostream& operator<<(std::ostream& o, wasm::Literals literals); + } // namespace wasm namespace std { diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index 197b7bc48..176642fd6 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -192,7 +192,7 @@ struct Precompute // TODO: handle multivalue types return; } - if (flow.getSingleValue().type.isVector()) { + if (flow.getType().hasVector()) { return; } if (flow.breaking()) { @@ -203,26 +203,24 @@ struct Precompute // this expression causes a return. if it's already a return, reuse the // node if (auto* ret = curr->dynCast<Return>()) { - if (flow.getSingleValue().type != Type::none) { + if (flow.values.size() > 0) { // reuse a const value if there is one - if (ret->value) { + if (ret->value && flow.values.size() == 1) { if (auto* value = ret->value->dynCast<Const>()) { value->value = flow.getSingleValue(); value->finalize(); return; } } - ret->value = - Builder(*getModule()).makeConstExpression(flow.getSingleValue()); + ret->value = flow.getConstExpression(*getModule()); } else { ret->value = nullptr; } } else { Builder builder(*getModule()); replaceCurrent(builder.makeReturn( - flow.getSingleValue().type != Type::none - ? builder.makeConstExpression(flow.getSingleValue()) - : nullptr)); + flow.values.size() > 0 ? flow.getConstExpression(*getModule()) + : nullptr)); } return; } @@ -231,9 +229,9 @@ struct Precompute if (auto* br = curr->dynCast<Break>()) { br->name = flow.breakTo; br->condition = nullptr; - if (flow.getSingleValue().type != Type::none) { + if (flow.values.size() > 0) { // reuse a const value if there is one - if (br->value) { + if (br->value && flow.values.size() == 1) { if (auto* value = br->value->dynCast<Const>()) { value->value = flow.getSingleValue(); value->finalize(); @@ -241,8 +239,7 @@ struct Precompute return; } } - br->value = - Builder(*getModule()).makeConstExpression(flow.getSingleValue()); + br->value = flow.getConstExpression(*getModule()); } else { br->value = nullptr; } @@ -251,16 +248,14 @@ struct Precompute Builder builder(*getModule()); replaceCurrent(builder.makeBreak( flow.breakTo, - flow.getSingleValue().type != Type::none - ? builder.makeConstExpression(flow.getSingleValue()) - : nullptr)); + flow.values.size() > 0 ? flow.getConstExpression(*getModule()) + : nullptr)); } return; } // this was precomputed - if (flow.getSingleValue().type.isConcrete()) { - replaceCurrent( - Builder(*getModule()).makeConstExpression(flow.getSingleValue())); + if (flow.getType().isConcrete()) { + replaceCurrent(flow.getConstExpression(*getModule())); worked = true; } else { ExpressionManipulator::nop(curr); diff --git a/src/passes/SimplifyGlobals.cpp b/src/passes/SimplifyGlobals.cpp index aa211b86b..44e3338e6 100644 --- a/src/passes/SimplifyGlobals.cpp +++ b/src/passes/SimplifyGlobals.cpp @@ -109,7 +109,7 @@ struct ConstantGlobalApplier if (auto* set = curr->dynCast<GlobalSet>()) { if (Properties::isConstantExpression(set->value)) { currConstantGlobals[set->name] = - getLiteralFromConstExpression(set->value); + getSingleLiteralFromConstExpression(set->value); } else { currConstantGlobals.erase(set->name); } @@ -253,7 +253,7 @@ struct SimplifyGlobals : public Pass { if (!global->imported()) { if (Properties::isConstantExpression(global->init)) { constantGlobals[global->name] = - getLiteralFromConstExpression(global->init); + getSingleLiteralFromConstExpression(global->init); } else if (auto* get = global->init->dynCast<GlobalGet>()) { auto iter = constantGlobals.find(get->name); if (iter != constantGlobals.end()) { diff --git a/src/shell-interface.h b/src/shell-interface.h index 03b190626..de4fe357a 100644 --- a/src/shell-interface.h +++ b/src/shell-interface.h @@ -134,12 +134,12 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { } } - Literal callImport(Function* import, LiteralList& arguments) override { + Literals callImport(Function* import, LiteralList& arguments) override { if (import->module == SPECTEST && import->base.startsWith(PRINT)) { for (auto argument : arguments) { std::cout << argument << " : " << argument.type << '\n'; } - return Literal(); + return {}; } else if (import->module == ENV && import->base == EXIT) { // XXX hack for torture tests std::cout << "exit()\n"; @@ -149,11 +149,11 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { << import->name.str; } - Literal callTable(Index index, - Signature sig, - LiteralList& arguments, - Type results, - ModuleInstance& instance) override { + Literals callTable(Index index, + Signature sig, + LiteralList& arguments, + Type results, + ModuleInstance& instance) override { if (index >= table.size()) { trap("callTable overflow"); } diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h index 7787dba25..8e7371c5f 100644 --- a/src/tools/execution-results.h +++ b/src/tools/execution-results.h @@ -32,7 +32,7 @@ struct LoggingExternalInterface : public ShellExternalInterface { LoggingExternalInterface(Loggings& loggings) : loggings(loggings) {} - Literal callImport(Function* import, LiteralList& arguments) override { + Literals callImport(Function* import, LiteralList& arguments) override { if (import->module == "fuzzing-support") { std::cout << "[LoggingExternalInterface logging"; loggings.push_back(Literal()); // buffer with a None between calls @@ -42,7 +42,7 @@ struct LoggingExternalInterface : public ShellExternalInterface { } std::cout << "]\n"; } - return Literal(); + return {}; } }; @@ -51,7 +51,7 @@ struct LoggingExternalInterface : public ShellExternalInterface { // we can only get results when there are no imports. we then call each method // that has a result, with some values struct ExecutionResults { - std::map<Name, Literal> results; + std::map<Name, Literals> results; Loggings loggings; // get results of execution @@ -69,18 +69,21 @@ struct ExecutionResults { auto* func = wasm.getFunction(exp->value); if (func->sig.results != Type::none) { // this has a result - Literal ret = run(func, wasm, instance); + Literals ret = run(func, wasm, instance); // We cannot compare funcrefs by name because function names can // change (after duplicate function elimination or roundtripping) // while the function contents are still the same - if (ret.type != Type::funcref) { - results[exp->name] = ret; - // ignore the result if we hit an unreachable and returned no value - if (results[exp->name].type.isConcrete()) { - std::cout << "[fuzz-exec] note result: " << exp->name << " => " - << results[exp->name] << '\n'; + for (Literal& val : ret) { + if (val.type == Type::funcref) { + val = Literal::makeFuncref(Name("funcref")); } } + results[exp->name] = ret; + // ignore the result if we hit an unreachable and returned no value + if (ret.size() > 0) { + std::cout << "[fuzz-exec] note result: " << exp->name << " => " + << ret << '\n'; + } } else { // no result, run it anyhow (it might modify memory etc.) run(func, wasm, instance); @@ -123,18 +126,18 @@ struct ExecutionResults { bool operator!=(ExecutionResults& other) { return !((*this) == other); } - Literal run(Function* func, Module& wasm) { + Literals run(Function* func, Module& wasm) { LoggingExternalInterface interface(loggings); try { ModuleInstance instance(wasm, &interface); return run(func, wasm, instance); } catch (const TrapException&) { // may throw in instance creation (init of offsets) - return Literal(); + return {}; } } - Literal run(Function* func, Module& wasm, ModuleInstance& instance) { + Literals run(Function* func, Module& wasm, ModuleInstance& instance) { try { LiteralList arguments; // init hang support, if present @@ -148,7 +151,7 @@ struct ExecutionResults { } return instance.callFunction(func->name, arguments); } catch (const TrapException&) { - return Literal(); + return {}; } } }; diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp index 3aef10cf0..2d0ef1ab4 100644 --- a/src/tools/wasm-ctor-eval.cpp +++ b/src/tools/wasm-ctor-eval.cpp @@ -203,7 +203,7 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface { }); } - Literal callImport(Function* import, LiteralList& arguments) override { + Literals callImport(Function* import, LiteralList& arguments) override { std::string extra; if (import->module == ENV && import->base == "___cxa_atexit") { extra = "\nrecommendation: build with -s NO_EXIT_RUNTIME=1 so that calls " @@ -214,11 +214,11 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface { extra); } - Literal callTable(Index index, - Signature sig, - LiteralList& arguments, - Type result, - EvallingModuleInstance& instance) override { + Literals callTable(Index index, + Signature sig, + LiteralList& arguments, + Type result, + EvallingModuleInstance& instance) override { // we assume the table is not modified (hmm) // look through the segments, try to find the function for (auto& segment : wasm->table.segments) { diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 6c9d3f36a..301da6cac 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -74,15 +74,15 @@ struct Operation { name = element[i++]->str(); for (size_t j = i; j < element.size(); j++) { Expression* argument = builder.parseExpression(*element[j]); - arguments.push_back(getLiteralFromConstExpression(argument)); + arguments.push_back(getSingleLiteralFromConstExpression(argument)); } } - Literal operate() { + Literals operate() { if (operation == INVOKE) { return instance->callExport(name, arguments); } else if (operation == GET) { - return instance->getExport(name); + return {instance->getExport(name)}; } else { WASM_UNREACHABLE("unknown operation"); } @@ -203,7 +203,7 @@ static void run_asserts(Name moduleName, // an invoke test bool trapped = false; WASM_UNUSED(trapped); - Literal result; + Literals result; try { Operation operation(*curr[1], instance, *builder); result = operation.operate(); @@ -212,21 +212,15 @@ static void run_asserts(Name moduleName, } if (id == ASSERT_RETURN) { assert(!trapped); + Literals expected; if (curr.size() >= 3) { - Literal expected = - getLiteralFromConstExpression(builder->parseExpression(*curr[2])); - std::cerr << "seen " << result << ", expected " << expected << '\n'; - if (expected != result) { - std::cout << "unexpected, should be identical\n"; - abort(); - } - } else { - Literal expected; - std::cerr << "seen " << result << ", expected " << expected << '\n'; - if (expected != result) { - std::cout << "unexpected, should be identical\n"; - abort(); - } + expected = + getLiteralsFromConstExpression(builder->parseExpression(*curr[2])); + } + std::cerr << "seen " << result << ", expected " << expected << '\n'; + if (expected != result) { + std::cout << "unexpected, should be identical\n"; + abort(); } } if (id == ASSERT_TRAP) { diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 153453f2b..2b8fa479a 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -30,6 +30,7 @@ #include "ir/module-utils.h" #include "support/bits.h" #include "support/safe_integer.h" +#include "wasm-builder.h" #include "wasm-traversal.h" #include "wasm.h" @@ -49,19 +50,42 @@ extern Name WASM, RETURN_FLOW; // in control flow. class Flow { public: - Flow() : values{Literal()} {} - Flow(Literal value) : values{value} {} - Flow(Name breakTo) : values{Literal()}, breakTo(breakTo) {} + Flow() : values() {} + Flow(Literal value) : values{value} { assert(value.type.isConcrete()); } + Flow(Literals&& values) : values(values) {} + Flow(Name breakTo) : values(), breakTo(breakTo) {} - SmallVector<Literal, 1> values; + Literals values; Name breakTo; // if non-null, a break is going on // A helper function for the common case where there is only one value - Literal& getSingleValue() { + const Literal& getSingleValue() { assert(values.size() == 1); return values[0]; } + Type getType() { + std::vector<Type> types; + for (auto& val : values) { + types.push_back(val.type); + } + return Type(types); + } + + Expression* getConstExpression(Module& module) { + assert(values.size() > 0); + Builder builder(module); + if (values.size() == 1) { + return builder.makeConstExpression(getSingleValue()); + } else { + std::vector<Expression*> consts; + for (auto& val : values) { + consts.push_back(builder.makeConstExpression(val)); + } + return builder.makeTupleMake(std::move(consts)); + } + } + bool breaking() { return breakTo.is(); } void clearIf(Name target) { @@ -167,16 +191,18 @@ public: trap("interpreter recursion limit"); } auto ret = OverriddenVisitor<SubType, Flow>::visit(curr); - if (!ret.breaking() && - (curr->type.isConcrete() || ret.getSingleValue().type.isConcrete())) { + if (!ret.breaking()) { + Type type = ret.getType(); + if (type.isConcrete() || curr->type.isConcrete()) { #if 1 // def WASM_INTERPRETER_DEBUG - if (!Type::isSubType(ret.getSingleValue().type, curr->type)) { - std::cerr << "expected " << curr->type << ", seeing " - << ret.getSingleValue().type << " from\n" - << curr << '\n'; - } + if (!Type::isSubType(type, curr->type)) { + std::cerr << "expected " << curr->type << ", seeing " << type + << " from\n" + << curr << '\n'; + } #endif - assert(Type::isSubType(ret.getSingleValue().type, curr->type)); + assert(Type::isSubType(type, curr->type)); + } } depth--; return ret; @@ -274,14 +300,13 @@ public: Flow visitSwitch(Switch* curr) { NOTE_ENTER("Switch"); Flow flow; - Literal value; + Literals values; if (curr->value) { flow = visit(curr->value); if (flow.breaking()) { return flow; } - value = flow.getSingleValue(); - NOTE_EVAL1(value); + values = flow.values; } flow = visit(curr->condition); if (flow.breaking()) { @@ -293,7 +318,7 @@ public: target = curr->targets[(size_t)index]; } flow.breakTo = target; - flow.getSingleValue() = value; + flow.values = values; return flow; } @@ -1136,6 +1161,7 @@ public: return flow; } for (auto arg : arguments) { + assert(arg.type.isConcrete()); flow.values.push_back(arg); } return flow; @@ -1234,12 +1260,12 @@ public: struct ExternalInterface { virtual void init(Module& wasm, SubType& instance) {} virtual void importGlobals(GlobalManager& globals, Module& wasm) = 0; - virtual Literal callImport(Function* import, LiteralList& arguments) = 0; - virtual Literal callTable(Index index, - Signature sig, - LiteralList& arguments, - Type result, - SubType& instance) = 0; + virtual Literals callImport(Function* import, LiteralList& arguments) = 0; + virtual Literals callTable(Index index, + Signature sig, + LiteralList& arguments, + Type result, + SubType& instance) = 0; virtual void growMemory(Address oldSize, Address newSize) = 0; virtual void trap(const char* why) = 0; @@ -1424,7 +1450,7 @@ public: } // call an exported function - Literal callExport(Name name, const LiteralList& arguments) { + Literals callExport(Name name, const LiteralList& arguments) { Export* export_ = wasm.getExportOrNull(name); if (!export_) { externalInterface->trap("callExport not found"); @@ -1578,9 +1604,9 @@ private: auto* func = instance.wasm.getFunction(curr->target); Flow ret; if (func->imported()) { - ret = instance.externalInterface->callImport(func, arguments); + ret.values = instance.externalInterface->callImport(func, arguments); } else { - ret = instance.callFunctionInternal(curr->target, arguments); + ret.values = instance.callFunctionInternal(curr->target, arguments); } #ifdef WASM_INTERPRETER_DEBUG std::cout << "(returned to " << scope.function->name << ")\n"; @@ -2095,7 +2121,7 @@ private: public: // Call a function, starting an invocation. - Literal callFunction(Name name, const LiteralList& arguments) { + Literals callFunction(Name name, const LiteralList& arguments) { // if the last call ended in a jump up the stack, it might have left stuff // for us to clean up here callDepth = 0; @@ -2105,7 +2131,7 @@ public: // Internal function call. Must be public so that callTable implementations // can use it (refactor?) - Literal callFunctionInternal(Name name, const LiteralList& arguments) { + Literals callFunctionInternal(Name name, const LiteralList& arguments) { if (callDepth > maxDepth) { externalInterface->trap("stack limit"); } @@ -2129,12 +2155,12 @@ public: RuntimeExpressionRunner(*this, scope, maxDepth).visit(function->body); // cannot still be breaking, it means we missed our stop assert(!flow.breaking() || flow.breakTo == RETURN_FLOW); - Literal ret = flow.getSingleValue(); - if (!Type::isSubType(ret.type, function->sig.results)) { - std::cerr << "calling " << function->name << " resulted in " << ret + auto type = flow.getType(); + if (!Type::isSubType(type, function->sig.results)) { + std::cerr << "calling " << function->name << " resulted in " << type << " but the function type is " << function->sig.results << '\n'; - WASM_UNREACHABLE("unexpect result type"); + WASM_UNREACHABLE("unexpected result type"); } // may decrease more than one, if we jumped up the stack callDepth = previousCallDepth; @@ -2145,7 +2171,7 @@ public: #ifdef WASM_INTERPRETER_DEBUG std::cout << "exiting " << function->name << " with " << ret << '\n'; #endif - return ret; + return flow.values; } protected: diff --git a/src/wasm-type.h b/src/wasm-type.h index 0e3907af8..e62c59016 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -72,6 +72,21 @@ public: constexpr bool isVector() const { return id == v128; }; constexpr bool isNumber() const { return id >= i32 && id <= v128; } constexpr bool isRef() const { return id >= funcref && id <= exnref; } + +private: + template<bool (Type::*pred)() const> bool hasPredicate() { + for (auto t : expand()) { + if ((t.*pred)()) { + return true; + } + } + return false; + } + +public: + bool hasVector() { return hasPredicate<&Type::isVector>(); } + bool hasRef() { return hasPredicate<&Type::isRef>(); } + constexpr uint32_t getID() const { return id; } constexpr ValueType getSingle() const { assert(!isMulti() && "Unexpected multivalue type"); diff --git a/src/wasm.h b/src/wasm.h index 507ffc047..a89353de6 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -574,7 +574,8 @@ public: const char* getExpressionName(Expression* curr); -Literal getLiteralFromConstExpression(Expression* curr); +Literal getSingleLiteralFromConstExpression(Expression* curr); +Literals getLiteralsFromConstExpression(Expression* curr); typedef ArenaVector<Expression*> ExpressionList; diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index d5333847d..1cf867cdd 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -271,10 +271,10 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { o << "?"; break; case Type::i32: - o << literal.i32; + o << literal.geti32(); break; case Type::i64: - o << literal.i64; + o << literal.geti64(); break; case Type::f32: literal.printFloat(o, literal.getf32()); @@ -301,6 +301,21 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { return o; } +std::ostream& operator<<(std::ostream& o, wasm::Literals literals) { + if (literals.size() == 1) { + return o << literals[0]; + } else { + o << '('; + if (literals.size() > 0) { + o << literals[0]; + } + for (size_t i = 1; i < literals.size(); ++i) { + o << ", " << literals[i]; + } + return o << ')'; + } +} + Literal Literal::countLeadingZeroes() const { if (type == Type::i32) { return Literal((int32_t)CountLeadingZeroes(i32)); diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 64bc5615f..4e4186d08 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -198,7 +198,7 @@ const char* getExpressionName(Expression* curr) { WASM_UNREACHABLE("invalid expr id"); } -Literal getLiteralFromConstExpression(Expression* curr) { +Literal getSingleLiteralFromConstExpression(Expression* curr) { if (auto* c = curr->dynCast<Const>()) { return c->value; } else if (curr->is<RefNull>()) { @@ -210,6 +210,18 @@ Literal getLiteralFromConstExpression(Expression* curr) { } } +Literals getLiteralsFromConstExpression(Expression* curr) { + if (auto* t = curr->dynCast<TupleMake>()) { + Literals values; + for (auto* operand : t->operands) { + values.push_back(getSingleLiteralFromConstExpression(operand)); + } + return values; + } else { + return {getSingleLiteralFromConstExpression(curr)}; + } +} + // core AST type checking struct TypeSeeker : public PostWalker<TypeSeeker> { |