diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/ExpressionAnalyzer.cpp | 6 | ||||
-rw-r--r-- | src/ir/ExpressionManipulator.cpp | 5 | ||||
-rw-r--r-- | src/ir/effects.h | 10 | ||||
-rw-r--r-- | src/passes/Asyncify.cpp | 6 | ||||
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 54 | ||||
-rw-r--r-- | src/passes/DeadCodeElimination.cpp | 10 | ||||
-rw-r--r-- | src/passes/Directize.cpp | 3 | ||||
-rw-r--r-- | src/passes/I64ToI32Lowering.cpp | 14 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 87 | ||||
-rw-r--r-- | src/tools/fuzzing.h | 23 | ||||
-rw-r--r-- | src/wasm-builder.h | 2 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 24 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 60 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 10 | ||||
-rw-r--r-- | src/wasm2js.h | 6 |
15 files changed, 266 insertions, 54 deletions
diff --git a/src/ir/ExpressionAnalyzer.cpp b/src/ir/ExpressionAnalyzer.cpp index 170202888..99c2798b0 100644 --- a/src/ir/ExpressionAnalyzer.cpp +++ b/src/ir/ExpressionAnalyzer.cpp @@ -132,9 +132,13 @@ template<typename T> void visitImmediates(Expression* curr, T& visitor) { } visitor.visitScopeName(curr->default_); } - void visitCall(Call* curr) { visitor.visitNonScopeName(curr->target); } + void visitCall(Call* curr) { + visitor.visitNonScopeName(curr->target); + visitor.visitInt(curr->isReturn); + } void visitCallIndirect(CallIndirect* curr) { visitor.visitNonScopeName(curr->fullType); + visitor.visitInt(curr->isReturn); } void visitLocalGet(LocalGet* curr) { visitor.visitIndex(curr->index); } void visitLocalSet(LocalSet* curr) { visitor.visitIndex(curr->index); } diff --git a/src/ir/ExpressionManipulator.cpp b/src/ir/ExpressionManipulator.cpp index b52a8dd82..783342780 100644 --- a/src/ir/ExpressionManipulator.cpp +++ b/src/ir/ExpressionManipulator.cpp @@ -71,7 +71,8 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { copy(curr->value)); } Expression* visitCall(Call* curr) { - auto* ret = builder.makeCall(curr->target, {}, curr->type); + auto* ret = + builder.makeCall(curr->target, {}, curr->type, curr->isReturn); for (Index i = 0; i < curr->operands.size(); i++) { ret->operands.push_back(copy(curr->operands[i])); } @@ -79,7 +80,7 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { } Expression* visitCallIndirect(CallIndirect* curr) { auto* ret = builder.makeCallIndirect( - curr->fullType, copy(curr->target), {}, curr->type); + curr->fullType, copy(curr->target), {}, curr->type, curr->isReturn); for (Index i = 0; i < curr->operands.size(); i++) { ret->operands.push_back(copy(curr->operands[i])); } diff --git a/src/ir/effects.h b/src/ir/effects.h index 14cbd6217..dac5b878a 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -223,6 +223,9 @@ struct EffectAnalyzer void visitCall(Call* curr) { calls = true; + if (curr->isReturn) { + branches = true; + } if (debugInfo) { // debugInfo call imports must be preserved very strongly, do not // move code around them @@ -230,7 +233,12 @@ struct EffectAnalyzer branches = true; } } - void visitCallIndirect(CallIndirect* curr) { calls = true; } + void visitCallIndirect(CallIndirect* curr) { + calls = true; + if (curr->isReturn) { + branches = true; + } + } void visitLocalGet(LocalGet* curr) { localsRead.insert(curr->index); } void visitLocalSet(LocalSet* curr) { localsWritten.insert(curr->index); } void visitGlobalGet(GlobalGet* curr) { globalsRead.insert(curr->name); } diff --git a/src/passes/Asyncify.cpp b/src/passes/Asyncify.cpp index cb80b90c7..98b4425b1 100644 --- a/src/passes/Asyncify.cpp +++ b/src/passes/Asyncify.cpp @@ -349,6 +349,9 @@ public: } struct Walker : PostWalker<Walker> { void visitCall(Call* curr) { + if (curr->isReturn) { + Fatal() << "tail calls not yet supported in aysncify"; + } auto* target = module->getFunction(curr->target); if (target->imported() && target->module == ASYNCIFY) { // Redirect the imports to the functions we'll add later. @@ -375,6 +378,9 @@ public: info->callsTo.insert(target); } void visitCallIndirect(CallIndirect* curr) { + if (curr->isReturn) { + Fatal() << "tail calls not yet supported in aysncify"; + } if (canIndirectChangeState) { info->canChangeState = true; } diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 789332b93..9ef762f98 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -57,6 +57,16 @@ struct DAEFunctionInfo { // Map of all calls that are dropped, to their drops' locations (so that // if we can optimize out the drop, we can replace the drop there). std::unordered_map<Call*, Expression**> droppedCalls; + // Whether this function contains any tail calls (including indirect tail + // calls) and the set of functions this function tail calls. Tail-callers and + // tail-callees cannot have their dropped returns removed because of the + // constraint that tail-callees must have the same return type as + // tail-callers. Indirectly tail called functions are already not optimized + // because being in a table inhibits DAE. TODO: Allow the removal of dropped + // returns from tail-callers if their tail-callees can have their returns + // removed as well. + bool hasTailCalls = false; + std::unordered_set<Name> tailCallees; // Whether the function can be called from places that // affect what we can do. For now, any call we don't // see inhibits our optimizations, but TODO: an export @@ -117,6 +127,16 @@ struct DAEScanner if (!getModule()->getFunction(curr->target)->imported()) { info->calls[curr->target].push_back(curr); } + if (curr->isReturn) { + info->hasTailCalls = true; + info->tailCallees.insert(curr->target); + } + } + + void visitCallIndirect(CallIndirect* curr) { + if (curr->isReturn) { + info->hasTailCalls = true; + } } void visitDrop(Drop* curr) { @@ -239,6 +259,7 @@ struct DAE : public Pass { DAEScanner(&infoMap).run(runner, module); // Combine all the info. std::unordered_map<Name, std::vector<Call*>> allCalls; + std::unordered_set<Name> tailCallees; for (auto& pair : infoMap) { auto& info = pair.second; for (auto& pair : info.calls) { @@ -247,6 +268,9 @@ struct DAE : public Pass { auto& allCallsToName = allCalls[name]; allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end()); } + for (auto& callee : info.tailCallees) { + tailCallees.insert(callee); + } for (auto& pair : info.droppedCalls) { allDroppedCalls[pair.first] = pair.second; } @@ -314,14 +338,11 @@ struct DAE : public Pass { // Great, it's not used. Check if none of the calls has a param with // side effects, as that would prevent us removing them (flattening // should have been done earlier). - bool canRemove = true; - for (auto* call : calls) { - auto* operand = call->operands[i]; - if (EffectAnalyzer(runner->options, operand).hasSideEffects()) { - canRemove = false; - break; - } - } + bool canRemove = + std::none_of(calls.begin(), calls.end(), [&](Call* call) { + auto* operand = call->operands[i]; + return EffectAnalyzer(runner->options, operand).hasSideEffects(); + }); if (canRemove) { // Wonderful, nothing stands in our way! Do it. // TODO: parallelize this? @@ -348,18 +369,21 @@ struct DAE : public Pass { if (infoMap[name].hasUnseenCalls) { continue; } + if (infoMap[name].hasTailCalls) { + continue; + } + if (tailCallees.find(name) != tailCallees.end()) { + continue; + } auto iter = allCalls.find(name); if (iter == allCalls.end()) { continue; } auto& calls = iter->second; - bool allDropped = true; - for (auto* call : calls) { - if (!allDroppedCalls.count(call)) { - allDropped = false; - break; - } - } + bool allDropped = + std::all_of(calls.begin(), calls.end(), [&](Call* call) { + return allDroppedCalls.count(call); + }); if (!allDropped) { continue; } diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index 6e6f2c2b9..61d16afc0 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -365,7 +365,12 @@ struct DeadCodeElimination return curr; } - void visitCall(Call* curr) { handleCall(curr); } + void visitCall(Call* curr) { + handleCall(curr); + if (curr->isReturn) { + reachable = false; + } + } void visitCallIndirect(CallIndirect* curr) { if (handleCall(curr) != curr) { @@ -380,6 +385,9 @@ struct DeadCodeElimination block->finalize(curr->type); replaceCurrent(block); } + if (curr->isReturn) { + reachable = false; + } } // Append the reachable operands of the current node to a block, and replace diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index f6969cce5..663b8257b 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -65,7 +65,8 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { } // Everything looks good! replaceCurrent( - Builder(*getModule()).makeCall(name, curr->operands, curr->type)); + Builder(*getModule()) + .makeCall(name, curr->operands, curr->type, curr->isReturn)); } } diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp index e893b6296..5d7d5998c 100644 --- a/src/passes/I64ToI32Lowering.cpp +++ b/src/passes/I64ToI32Lowering.cpp @@ -262,9 +262,14 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { return call; } void visitCall(Call* curr) { + if (curr->isReturn && + getModule()->getFunction(curr->target)->result == i64) { + Fatal() + << "i64 to i32 lowering of return_call values not yet implemented"; + } auto* fixedCall = visitGenericCall<Call>( curr, [&](std::vector<Expression*>& args, Type ty) { - return builder->makeCall(curr->target, args, ty); + return builder->makeCall(curr->target, args, ty, curr->isReturn); }); // If this was to an import, we need to call the legal version. This assumes // that legalize-js-interface has been run before. @@ -275,10 +280,15 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { } void visitCallIndirect(CallIndirect* curr) { + if (curr->isReturn && + getModule()->getFunctionType(curr->fullType)->result == i64) { + Fatal() + << "i64 to i32 lowering of return_call values not yet implemented"; + } visitGenericCall<CallIndirect>( curr, [&](std::vector<Expression*>& args, Type ty) { return builder->makeCallIndirect( - curr->fullType, curr->target, args, ty); + curr->fullType, curr->target, args, ty, curr->isReturn); }); } diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index a625464cd..9ba01c4c1 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -146,7 +146,17 @@ struct Planner : public WalkerPass<PostWalker<Planner>> { // plan to inline if we know this is valid to inline, and if the call is // actually performed - if it is dead code, it's pointless to inline. // we also cannot inline ourselves. - if (state->worthInlining.count(curr->target) && curr->type != unreachable && + bool isUnreachable; + if (curr->isReturn) { + // Tail calls are only actually unreachable if an argument is + isUnreachable = + std::any_of(curr->operands.begin(), + curr->operands.end(), + [](Expression* op) { return op->type == unreachable; }); + } else { + isUnreachable = curr->type == unreachable; + } + if (state->worthInlining.count(curr->target) && !isUnreachable && curr->target != getFunction()->name) { // nest the call in a block. that way the location of the pointer to the // call will not change even if we inline multiple times into the same @@ -164,32 +174,69 @@ private: InliningState* state; }; +struct Updater : public PostWalker<Updater> { + Module* module; + std::map<Index, Index> localMapping; + Name returnName; + Builder* builder; + void visitReturn(Return* curr) { + replaceCurrent(builder->makeBreak(returnName, curr->value)); + } + // Return calls in inlined functions should only break out of the scope of + // the inlined code, not the entire function they are being inlined into. To + // achieve this, make the call a non-return call and add a break. This does + // not cause unbounded stack growth because inlining and return calling both + // avoid creating a new stack frame. + template<typename T> void handleReturnCall(T* curr, Type targetType) { + curr->isReturn = false; + curr->type = targetType; + if (isConcreteType(targetType)) { + replaceCurrent(builder->makeBreak(returnName, curr)); + } else { + replaceCurrent(builder->blockify(curr, builder->makeBreak(returnName))); + } + } + void visitCall(Call* curr) { + if (curr->isReturn) { + handleReturnCall(curr, module->getFunction(curr->target)->result); + } + } + void visitCallIndirect(CallIndirect* curr) { + if (curr->isReturn) { + handleReturnCall(curr, module->getFunctionType(curr->fullType)->result); + } + } + void visitLocalGet(LocalGet* curr) { + curr->index = localMapping[curr->index]; + } + void visitLocalSet(LocalSet* curr) { + curr->index = localMapping[curr->index]; + } +}; + // Core inlining logic. Modifies the outside function (adding locals as // needed), and returns the inlined code. static Expression* doInlining(Module* module, Function* into, InliningAction& action) { Function* from = action.contents; auto* call = (*action.callSite)->cast<Call>(); + // Works for return_call, too + Type retType = module->getFunction(call->target)->result; Builder builder(*module); - auto* block = Builder(*module).makeBlock(); + auto* block = builder.makeBlock(); block->name = Name(std::string("__inlined_func$") + from->name.str); - *action.callSite = block; - // Prepare to update the inlined code's locals and other things. - struct Updater : public PostWalker<Updater> { - std::map<Index, Index> localMapping; - Name returnName; - Builder* builder; - - void visitReturn(Return* curr) { - replaceCurrent(builder->makeBreak(returnName, curr->value)); + if (call->isReturn) { + if (isConcreteType(retType)) { + *action.callSite = builder.makeReturn(block); + } else { + *action.callSite = builder.makeSequence(block, builder.makeReturn()); } - void visitLocalGet(LocalGet* curr) { - curr->index = localMapping[curr->index]; - } - void visitLocalSet(LocalSet* curr) { - curr->index = localMapping[curr->index]; - } - } updater; + } else { + *action.callSite = block; + } + // Prepare to update the inlined code's locals and other things. + Updater updater; + updater.module = module; updater.returnName = block->name; updater.builder = &builder; // Set up a locals mapping @@ -215,12 +262,12 @@ doInlining(Module* module, Function* into, InliningAction& action) { } updater.walk(contents); block->list.push_back(contents); - block->type = call->type; + block->type = retType; // If the function returned a value, we just set the block containing the // inlined code to have that type. or, if the function was void and // contained void, that is fine too. a bad case is a void function in which // we have unreachable code, so we would be replacing a void call with an - // unreachable; we need to handle + // unreachable. if (contents->type == unreachable && block->type == none) { // Make the block reachable by adding a break to it block->list.push_back(builder.makeBreak(block->name)); diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index 1d68219e2..6a947d80c 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -1200,12 +1200,15 @@ private: Expression* makeCall(Type type) { // seems ok, go on int tries = TRIES; + bool isReturn; while (tries-- > 0) { Function* target = func; if (!wasm.functions.empty() && !oneIn(wasm.functions.size())) { target = vectorPick(wasm.functions).get(); } - if (target->result != type) { + isReturn = type == unreachable && wasm.features.hasTailCall() && + func->result == target->result; + if (target->result != type && !isReturn) { continue; } // we found one! @@ -1213,7 +1216,7 @@ private: for (auto argType : target->params) { args.push_back(make(argType)); } - return builder.makeCall(target->name, args, type); + return builder.makeCall(target->name, args, type, isReturn); } // we failed to find something return make(type); @@ -1227,11 +1230,14 @@ private: // look for a call target with the right type Index start = upTo(data.size()); Index i = start; - Function* func; + Function* targetFn; + bool isReturn; while (1) { // TODO: handle unreachable - func = wasm.getFunction(data[i]); - if (func->result == type) { + targetFn = wasm.getFunction(data[i]); + isReturn = type == unreachable && wasm.features.hasTailCall() && + func->result == targetFn->result; + if (targetFn->result == type || isReturn) { break; } i++; @@ -1251,11 +1257,12 @@ private: target = make(i32); } std::vector<Expression*> args; - for (auto type : func->params) { + for (auto type : targetFn->params) { args.push_back(make(type)); } - func->type = ensureFunctionType(getSig(func), &wasm)->name; - return builder.makeCallIndirect(func->type, target, args, func->result); + targetFn->type = ensureFunctionType(getSig(targetFn), &wasm)->name; + return builder.makeCallIndirect( + targetFn->type, target, args, targetFn->result, isReturn); } Expression* makeLocalGet(Type type) { diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 69fcef3f1..923effb76 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -207,6 +207,7 @@ public: call->target = target; call->operands.set(args); call->isReturn = isReturn; + call->finalize(); return call; } CallIndirect* makeCallIndirect(FunctionType* type, @@ -226,6 +227,7 @@ public: call->target = target; call->operands.set(args); call->isReturn = isReturn; + call->finalize(); return call; } // FunctionType diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 42e904fd3..f683f9b38 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1435,6 +1435,15 @@ private: #ifdef WASM_INTERPRETER_DEBUG std::cout << "(returned to " << scope.function->name << ")\n"; #endif + // TODO: make this a proper tail call (return first) + if (curr->isReturn) { + Const c; + c.value = ret.value; + c.finalize(); + Return return_; + return_.value = &c; + return this->visit(&return_); + } return ret; } Flow visitCallIndirect(CallIndirect* curr) { @@ -1449,8 +1458,19 @@ private: return target; } Index index = target.value.geti32(); - return instance.externalInterface->callTable( - index, arguments, curr->type, *instance.self()); + Type type = curr->isReturn ? scope.function->result : curr->type; + Flow ret = instance.externalInterface->callTable( + index, arguments, type, *instance.self()); + // TODO: make this a proper tail call (return first) + if (curr->isReturn) { + Const c; + c.value = ret.value; + c.finalize(); + Return return_; + return_.value = &c; + return this->visit(&return_); + } + return ret; } Flow visitLocalGet(LocalGet* curr) { diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 8098116be..372c5bf6d 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -549,6 +549,10 @@ void FunctionValidator::noteBreak(Name name, } void FunctionValidator::visitBreak(Break* curr) { noteBreak(curr->name, curr->value, curr); + if (curr->value) { + shouldBeTrue( + curr->value->type != none, curr, "break value must not have none type"); + } if (curr->condition) { shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, @@ -593,6 +597,33 @@ void FunctionValidator::visitCall(Call* curr) { getStream() << "(on argument " << i << ")\n"; } } + if (curr->isReturn) { + shouldBeEqual(curr->type, + unreachable, + curr, + "return_call should have unreachable type"); + shouldBeEqual( + getFunction()->result, + target->result, + curr, + "return_call callee return type must match caller return type"); + } else { + if (curr->type == unreachable) { + bool hasUnreachableOperand = + std::any_of(curr->operands.begin(), + curr->operands.end(), + [](Expression* op) { return op->type == unreachable; }); + shouldBeTrue( + hasUnreachableOperand, + curr, + "calls may only be unreachable if they have unreachable operands"); + } else { + shouldBeEqual(curr->type, + target->result, + curr, + "call type must match callee return type"); + } + } } void FunctionValidator::visitCallIndirect(CallIndirect* curr) { @@ -622,6 +653,35 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { getStream() << "(on argument " << i << ")\n"; } } + if (curr->isReturn) { + shouldBeEqual(curr->type, + unreachable, + curr, + "return_call_indirect should have unreachable type"); + shouldBeEqual( + getFunction()->result, + type->result, + curr, + "return_call_indirect callee return type must match caller return type"); + } else { + if (curr->type == unreachable) { + if (curr->target->type != unreachable) { + bool hasUnreachableOperand = + std::any_of(curr->operands.begin(), + curr->operands.end(), + [](Expression* op) { return op->type == unreachable; }); + shouldBeTrue(hasUnreachableOperand, + curr, + "call_indirects may only be unreachable if they have " + "unreachable operands"); + } + } else { + shouldBeEqual(curr->type, + type->result, + curr, + "call_indirect type must match callee return type"); + } + } } void FunctionValidator::visitConst(Const* curr) { diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index d0a916072..04abbcd9f 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -405,10 +405,18 @@ template<typename T> void handleUnreachableOperands(T* curr) { } } -void Call::finalize() { handleUnreachableOperands(this); } +void Call::finalize() { + handleUnreachableOperands(this); + if (isReturn) { + type = unreachable; + } +} void CallIndirect::finalize() { handleUnreachableOperands(this); + if (isReturn) { + type = unreachable; + } if (target->type == unreachable) { type = unreachable; } diff --git a/src/wasm2js.h b/src/wasm2js.h index c5fa54028..b1548c0ca 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h @@ -1115,6 +1115,9 @@ Ref Wasm2JSBuilder::processFunctionBody(Module* m, } Ref visitCall(Call* curr) { + if (curr->isReturn) { + Fatal() << "tail calls not yet supported in wasm2js"; + } Ref theCall = ValueBuilder::makeCall(fromName(curr->target, NameScope::Top)); // For wasm => wasm calls, we don't need coercions. TODO: even imports @@ -1136,6 +1139,9 @@ Ref Wasm2JSBuilder::processFunctionBody(Module* m, } Ref visitCallIndirect(CallIndirect* curr) { + if (curr->isReturn) { + Fatal() << "tail calls not yet supported in wasm2js"; + } // If the target has effects that interact with the operands, we must // reorder it to the start. bool mustReorder = false; |