diff options
Diffstat (limited to 'src/passes')
-rw-r--r-- | src/passes/ConstHoisting.cpp | 9 | ||||
-rw-r--r-- | src/passes/DeadCodeElimination.cpp | 6 | ||||
-rw-r--r-- | src/passes/Flatten.cpp | 37 | ||||
-rw-r--r-- | src/passes/FuncCastEmulation.cpp | 16 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 17 | ||||
-rw-r--r-- | src/passes/InstrumentLocals.cpp | 32 | ||||
-rw-r--r-- | src/passes/LegalizeJSInterface.cpp | 33 | ||||
-rw-r--r-- | src/passes/LocalCSE.cpp | 7 | ||||
-rw-r--r-- | src/passes/MergeLocals.cpp | 15 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 6 | ||||
-rw-r--r-- | src/passes/Precompute.cpp | 19 | ||||
-rw-r--r-- | src/passes/Print.cpp | 40 | ||||
-rw-r--r-- | src/passes/RemoveUnusedModuleElements.cpp | 6 | ||||
-rw-r--r-- | src/passes/SimplifyGlobals.cpp | 19 | ||||
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 6 | ||||
-rw-r--r-- | src/passes/opt-utils.h | 13 |
16 files changed, 227 insertions, 54 deletions
diff --git a/src/passes/ConstHoisting.cpp b/src/passes/ConstHoisting.cpp index dbb3853d8..4e8cd9910 100644 --- a/src/passes/ConstHoisting.cpp +++ b/src/passes/ConstHoisting.cpp @@ -91,9 +91,12 @@ private: size = value.type.getByteSize(); break; } - case v128: // v128 not implemented yet - case anyref: // anyref cannot have literals - case exnref: { // exnref cannot have literals + // not implemented yet + case v128: + case funcref: + case anyref: + case nullref: + case exnref: { return false; } case none: diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index be6f92ffa..7d5385a83 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -347,6 +347,12 @@ struct DeadCodeElimination DELEGATE(Push); case Expression::Id::PopId: DELEGATE(Pop); + case Expression::Id::RefNullId: + DELEGATE(RefNull); + case Expression::Id::RefIsNullId: + DELEGATE(RefIsNull); + case Expression::Id::RefFuncId: + DELEGATE(RefFunc); case Expression::Id::TryId: DELEGATE(Try); case Expression::Id::ThrowId: diff --git a/src/passes/Flatten.cpp b/src/passes/Flatten.cpp index f5115567b..fda8e3f80 100644 --- a/src/passes/Flatten.cpp +++ b/src/passes/Flatten.cpp @@ -21,6 +21,7 @@ #include <ir/branch-utils.h> #include <ir/effects.h> #include <ir/flat.h> +#include <ir/properties.h> #include <ir/utils.h> #include <pass.h> #include <wasm-builder.h> @@ -61,7 +62,9 @@ struct Flatten std::vector<Expression*> ourPreludes; Builder builder(*getModule()); - if (curr->is<Const>() || curr->is<Nop>() || curr->is<Unreachable>()) { + // Nothing to do for constants, nop, and unreachable + if (Properties::isConstantExpression(curr) || curr->is<Nop>() || + curr->is<Unreachable>()) { return; } @@ -194,8 +197,37 @@ struct Flatten auto type = br->value->type; if (type.isConcrete()) { // we are sending a value. use a local instead - Index temp = getTempForBreakTarget(br->name, type); + Type blockType = findBreakTarget(br->name)->type; + Index temp = getTempForBreakTarget(br->name, blockType); ourPreludes.push_back(builder.makeLocalSet(temp, br->value)); + + // br_if leaves a value on the stack if not taken, which later can + // be the last element of the enclosing innermost block and flow + // out. The local we created using 'getTempForBreakTarget' returns + // the return type of the block this branch is targetting, which may + // not be the same with the innermost block's return type. For + // example, + // (block $any (result anyref) + // (block (result nullref) + // (local.tee $0 + // (br_if $any + // (ref.null) + // (i32.const 0) + // ) + // ) + // ) + // ) + // In this case we need two locals to store (ref.null); one with + // anyref type that's for the target block ($label0) and one more + // with nullref type in case for flowing out. Here we create the + // second 'flowing out' local in case two block's types are + // different. + if (type != blockType) { + temp = builder.addVar(getFunction(), type); + ourPreludes.push_back(builder.makeLocalSet( + temp, ExpressionManipulator::copy(br->value, *getModule()))); + } + if (br->condition) { // the value must also flow out ourPreludes.push_back(br); @@ -239,6 +271,7 @@ struct Flatten } } } + // TODO Handle br_on_exn // continue for general handling of everything, control flow or otherwise curr = getCurrent(); // we may have replaced it diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp index 729a4a6c3..9d5109a83 100644 --- a/src/passes/FuncCastEmulation.cpp +++ b/src/passes/FuncCastEmulation.cpp @@ -65,11 +65,11 @@ static Expression* toABI(Expression* value, Module* module) { case v128: { WASM_UNREACHABLE("v128 not implemented yet"); } - case anyref: { - WASM_UNREACHABLE("anyref cannot be converted to i64"); - } + case funcref: + case anyref: + case nullref: case exnref: { - WASM_UNREACHABLE("exnref cannot be converted to i64"); + WASM_UNREACHABLE("reference types cannot be converted to i64"); } case none: { // the value is none, but we need a value here @@ -108,11 +108,11 @@ static Expression* fromABI(Expression* value, Type type, Module* module) { case v128: { WASM_UNREACHABLE("v128 not implemented yet"); } - case anyref: { - WASM_UNREACHABLE("anyref cannot be converted from i64"); - } + case funcref: + case anyref: + case nullref: case exnref: { - WASM_UNREACHABLE("exnref cannot be converted from i64"); + WASM_UNREACHABLE("reference types cannot be converted from i64"); } case none: { value = builder.makeDrop(value); diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index db1db5971..c43d41e7f 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -46,13 +46,13 @@ namespace wasm { // Useful into on a function, helping us decide if we can inline it struct FunctionInfo { - std::atomic<Index> calls; + std::atomic<Index> refs; Index size; std::atomic<bool> lightweight; bool usedGlobally; // in a table or export FunctionInfo() { - calls = 0; + refs = 0; size = 0; lightweight = true; usedGlobally = false; @@ -75,7 +75,7 @@ struct FunctionInfo { // FIXME: move this check to be first in this function, since we should // return true if oneCallerInlineMaxSize is bigger than // flexibleInlineMaxSize (which it typically should be). - if (calls == 1 && !usedGlobally && + if (refs == 1 && !usedGlobally && size <= options.inlining.oneCallerInlineMaxSize) { return true; } @@ -108,11 +108,16 @@ struct FunctionInfoScanner void visitCall(Call* curr) { // can't add a new element in parallel assert(infos->count(curr->target) > 0); - (*infos)[curr->target].calls++; + (*infos)[curr->target].refs++; // having a call is not lightweight (*infos)[getFunction()->name].lightweight = false; } + void visitRefFunc(RefFunc* curr) { + assert(infos->count(curr->func) > 0); + (*infos)[curr->func].refs++; + } + void visitFunction(Function* curr) { (*infos)[curr->name].size = Measurer::measure(curr->body); } @@ -374,7 +379,7 @@ struct Inlining : public Pass { doInlining(module, func.get(), action); inlinedUses[inlinedName]++; inlinedInto.insert(func.get()); - assert(inlinedUses[inlinedName] <= infos[inlinedName].calls); + assert(inlinedUses[inlinedName] <= infos[inlinedName].refs); } } // anything we inlined into may now have non-unique label names, fix it up @@ -388,7 +393,7 @@ struct Inlining : public Pass { module->removeFunctions([&](Function* func) { auto name = func->name; auto& info = infos[name]; - return inlinedUses.count(name) && inlinedUses[name] == info.calls && + return inlinedUses.count(name) && inlinedUses[name] == info.refs && !info.usedGlobally; }); // return whether we did any work diff --git a/src/passes/InstrumentLocals.cpp b/src/passes/InstrumentLocals.cpp index 407903219..ae35ec2d1 100644 --- a/src/passes/InstrumentLocals.cpp +++ b/src/passes/InstrumentLocals.cpp @@ -56,14 +56,18 @@ Name get_i32("get_i32"); Name get_i64("get_i64"); Name get_f32("get_f32"); Name get_f64("get_f64"); +Name get_funcref("get_funcref"); Name get_anyref("get_anyref"); +Name get_nullref("get_nullref"); Name get_exnref("get_exnref"); Name set_i32("set_i32"); Name set_i64("set_i64"); Name set_f32("set_f32"); Name set_f64("set_f64"); +Name set_funcref("set_funcref"); Name set_anyref("set_anyref"); +Name set_nullref("set_nullref"); Name set_exnref("set_exnref"); struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { @@ -84,9 +88,15 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { break; case v128: assert(false && "v128 not implemented yet"); + case funcref: + import = get_funcref; + break; case anyref: import = get_anyref; break; + case nullref: + import = get_nullref; + break; case exnref: import = get_exnref; break; @@ -126,9 +136,15 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { break; case v128: assert(false && "v128 not implemented yet"); + case funcref: + import = set_funcref; + break; case anyref: import = set_anyref; break; + case nullref: + import = set_nullref; + break; case exnref: import = set_exnref; break; @@ -156,10 +172,26 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { addImport(curr, set_f64, {Type::i32, Type::i32, Type::f64}, Type::f64); if (curr->features.hasReferenceTypes()) { + addImport(curr, + get_funcref, + {Type::i32, Type::i32, Type::funcref}, + Type::funcref); + addImport(curr, + set_funcref, + {Type::i32, Type::i32, Type::funcref}, + Type::funcref); addImport( curr, get_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref); addImport( curr, set_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref); + addImport(curr, + get_nullref, + {Type::i32, Type::i32, Type::nullref}, + Type::nullref); + addImport(curr, + set_nullref, + {Type::i32, Type::i32, Type::nullref}, + Type::nullref); } if (curr->features.hasExceptionHandling()) { addImport( diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index 8c7bc4414..df6651b0d 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -107,14 +107,43 @@ struct LegalizeJSInterface : public Pass { } } } + if (!illegalImportsToLegal.empty()) { + // Gather functions used in 'ref.func'. They should not be removed. + std::unordered_map<Name, std::atomic<bool>> usedInRefFunc; + + struct RefFuncScanner : public WalkerPass<PostWalker<RefFuncScanner>> { + Module& wasm; + std::unordered_map<Name, std::atomic<bool>>& usedInRefFunc; + + bool isFunctionParallel() override { return true; } + + Pass* create() override { + return new RefFuncScanner(wasm, usedInRefFunc); + } + + RefFuncScanner( + Module& wasm, + std::unordered_map<Name, std::atomic<bool>>& usedInRefFunc) + : wasm(wasm), usedInRefFunc(usedInRefFunc) { + // Fill in unordered_map, as we operate on it in parallel + for (auto& func : wasm.functions) { + usedInRefFunc[func->name]; + } + } + + void visitRefFunc(RefFunc* curr) { usedInRefFunc[curr->func] = true; } + }; + + RefFuncScanner(*module, usedInRefFunc).run(runner, module); for (auto& pair : illegalImportsToLegal) { - module->removeFunction(pair.first); + if (!usedInRefFunc[pair.first]) { + module->removeFunction(pair.first); + } } // fix up imports: call_import of an illegal must be turned to a call of a // legal - struct FixImports : public WalkerPass<PostWalker<FixImports>> { bool isFunctionParallel() override { return true; } diff --git a/src/passes/LocalCSE.cpp b/src/passes/LocalCSE.cpp index 0816bf6ea..b49c92310 100644 --- a/src/passes/LocalCSE.cpp +++ b/src/passes/LocalCSE.cpp @@ -172,9 +172,12 @@ struct LocalCSE : public WalkerPass<LinearExecutionWalker<LocalCSE>> { void handle(Expression* curr) { if (auto* set = curr->dynCast<LocalSet>()) { // Calculate equivalences + auto* func = getFunction(); equivalences.reset(set->index); if (auto* get = set->value->dynCast<LocalGet>()) { - equivalences.add(set->index, get->index); + if (func->getLocalType(set->index) == func->getLocalType(get->index)) { + equivalences.add(set->index, get->index); + } } // consider the value auto* value = set->value; @@ -184,7 +187,7 @@ struct LocalCSE : public WalkerPass<LinearExecutionWalker<LocalCSE>> { if (iter != usables.end()) { // already exists in the table, this is good to reuse auto& info = iter->second; - Type localType = getFunction()->getLocalType(info.index); + Type localType = func->getLocalType(info.index); set->value = Builder(*getModule()).makeLocalGet(info.index, localType); anotherPass = true; diff --git a/src/passes/MergeLocals.cpp b/src/passes/MergeLocals.cpp index 0116753f1..2223594b6 100644 --- a/src/passes/MergeLocals.cpp +++ b/src/passes/MergeLocals.cpp @@ -100,7 +100,8 @@ struct MergeLocals return; } // compute all dependencies - LocalGraph preGraph(getFunction()); + auto* func = getFunction(); + LocalGraph preGraph(func); preGraph.computeInfluences(); // optimize each copy std::unordered_map<LocalSet*, LocalSet*> optimizedToCopy, @@ -119,6 +120,11 @@ struct MergeLocals if (preGraph.getSetses[influencedGet].size() == 1) { // this is ok assert(*preGraph.getSetses[influencedGet].begin() == trivial); + // If local types are different (when one is a subtype of the + // other), don't optimize + if (func->getLocalType(copy->index) != influencedGet->type) { + canOptimizeToCopy = false; + } } else { canOptimizeToCopy = false; break; @@ -152,6 +158,11 @@ struct MergeLocals if (preGraph.getSetses[influencedGet].size() == 1) { // this is ok assert(*preGraph.getSetses[influencedGet].begin() == copy); + // If local types are different (when one is a subtype of the + // other), don't optimize + if (func->getLocalType(trivial->index) != influencedGet->type) { + canOptimizeToTrivial = false; + } } else { canOptimizeToTrivial = false; break; @@ -176,7 +187,7 @@ struct MergeLocals // if one does not work, we need to undo all its siblings (don't extend // the live range unless we are definitely removing a conflict, same // logic as before). - LocalGraph postGraph(getFunction()); + LocalGraph postGraph(func); postGraph.computeInfluences(); for (auto& pair : optimizedToCopy) { auto* copy = pair.first; diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 6de1d3d00..edd6ba2b6 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -751,12 +751,12 @@ struct OptimizeInstructions // condition, do that auto needCondition = EffectAnalyzer(getPassOptions(), iff->condition).hasSideEffects(); - auto typeIsIdentical = iff->ifTrue->type == iff->type; - if (typeIsIdentical && !needCondition) { + auto isSubType = Type::isSubType(iff->ifTrue->type, iff->type); + if (isSubType && !needCondition) { return iff->ifTrue; } else { Builder builder(*getModule()); - if (typeIsIdentical) { + if (isSubType) { return builder.makeSequence(builder.makeDrop(iff->condition), iff->ifTrue); } else { diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index 57a3ab27f..85eb026f9 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -177,7 +177,7 @@ struct Precompute void visitExpression(Expression* curr) { // TODO: if local.get, only replace with a constant if we don't care about // size...? - if (curr->is<Const>() || curr->is<Nop>()) { + if (Properties::isConstantExpression(curr) || curr->is<Nop>()) { return; } // Until engines implement v128.const and we have SIMD-aware optimizations @@ -208,14 +208,16 @@ struct Precompute return; } } - ret->value = Builder(*getModule()).makeConst(flow.value); + ret->value = Builder(*getModule()).makeConstExpression(flow.value); } else { ret->value = nullptr; } } else { Builder builder(*getModule()); - replaceCurrent(builder.makeReturn( - flow.value.type != none ? builder.makeConst(flow.value) : nullptr)); + replaceCurrent( + builder.makeReturn(flow.value.type != Type::none + ? builder.makeConstExpression(flow.value) + : nullptr)); } return; } @@ -234,7 +236,7 @@ struct Precompute return; } } - br->value = Builder(*getModule()).makeConst(flow.value); + br->value = Builder(*getModule()).makeConstExpression(flow.value); } else { br->value = nullptr; } @@ -243,13 +245,14 @@ struct Precompute Builder builder(*getModule()); replaceCurrent(builder.makeBreak( flow.breakTo, - flow.value.type != none ? builder.makeConst(flow.value) : nullptr)); + flow.value.type != none ? builder.makeConstExpression(flow.value) + : nullptr)); } return; } // this was precomputed if (flow.value.type.isConcrete()) { - replaceCurrent(Builder(*getModule()).makeConst(flow.value)); + replaceCurrent(Builder(*getModule()).makeConstExpression(flow.value)); worked = true; } else { ExpressionManipulator::nop(curr); @@ -350,7 +353,7 @@ private: } else { curr = setValues[set]; } - if (curr.isNull()) { + if (curr.isNone()) { // not a constant, give up value = Literal(); break; diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 5efd1fd28..51e78c8a7 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -1333,7 +1333,12 @@ struct PrintExpressionContents } restoreNormalColor(o); } - void visitSelect(Select* curr) { prepareColor(o) << "select"; } + void visitSelect(Select* curr) { + prepareColor(o) << "select"; + if (curr->type.isRef()) { + o << " (result " << curr->type << ')'; + } + } void visitDrop(Drop* curr) { printMedium(o, "drop"); } void visitReturn(Return* curr) { printMedium(o, "return"); } void visitHost(Host* curr) { @@ -1346,6 +1351,12 @@ struct PrintExpressionContents break; } } + void visitRefNull(RefNull* curr) { printMedium(o, "ref.null"); } + void visitRefIsNull(RefIsNull* curr) { printMedium(o, "ref.is_null"); } + void visitRefFunc(RefFunc* curr) { + printMedium(o, "ref.func "); + printName(curr->func, o); + } void visitTry(Try* curr) { printMedium(o, "try"); if (curr->type.isConcrete()) { @@ -1852,6 +1863,23 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { } } } + void visitRefNull(RefNull* curr) { + o << '('; + PrintExpressionContents(currFunction, o).visit(curr); + o << ')'; + } + void visitRefIsNull(RefIsNull* curr) { + o << '('; + PrintExpressionContents(currFunction, o).visit(curr); + incIndent(); + printFullLine(curr->value); + decIndent(); + } + void visitRefFunc(RefFunc* curr) { + o << '('; + PrintExpressionContents(currFunction, o).visit(curr); + o << ')'; + } // try-catch-end is written in the folded wat format as // (try // ... @@ -2434,13 +2462,15 @@ WasmPrinter::printStackInst(StackInst* inst, std::ostream& o, Function* func) { } case StackInst::BlockBegin: case StackInst::IfBegin: - case StackInst::LoopBegin: { + case StackInst::LoopBegin: + case StackInst::TryBegin: { o << getExpressionName(inst->origin); break; } case StackInst::BlockEnd: case StackInst::IfEnd: - case StackInst::LoopEnd: { + case StackInst::LoopEnd: + case StackInst::TryEnd: { o << "end (" << inst->type << ')'; break; } @@ -2448,6 +2478,10 @@ WasmPrinter::printStackInst(StackInst* inst, std::ostream& o, Function* func) { o << "else"; break; } + case StackInst::Catch: { + o << "catch"; + break; + } default: WASM_UNREACHABLE("unexpeted op"); } diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index f5000e3a4..21cbc5e5b 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -116,6 +116,12 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { usesMemory = true; } } + void visitRefFunc(RefFunc* curr) { + if (reachable.count( + ModuleElement(ModuleElementKind::Function, curr->func)) == 0) { + queue.emplace_back(ModuleElementKind::Function, curr->func); + } + } void visitThrow(Throw* curr) { if (reachable.count(ModuleElement(ModuleElementKind::Event, curr->event)) == 0) { diff --git a/src/passes/SimplifyGlobals.cpp b/src/passes/SimplifyGlobals.cpp index 88f27f8be..b18f726ed 100644 --- a/src/passes/SimplifyGlobals.cpp +++ b/src/passes/SimplifyGlobals.cpp @@ -37,6 +37,7 @@ #include <atomic> #include "ir/effects.h" +#include "ir/properties.h" #include "ir/utils.h" #include "pass.h" #include "wasm-builder.h" @@ -106,8 +107,9 @@ struct ConstantGlobalApplier void visitExpression(Expression* curr) { if (auto* set = curr->dynCast<GlobalSet>()) { - if (auto* c = set->value->dynCast<Const>()) { - currConstantGlobals[set->name] = c->value; + if (Properties::isConstantExpression(set->value)) { + currConstantGlobals[set->name] = + getLiteralFromConstExpression(set->value); } else { currConstantGlobals.erase(set->name); } @@ -116,7 +118,7 @@ struct ConstantGlobalApplier // Check if the global is known to be constant all the time. if (constantGlobals->count(get->name)) { auto* global = getModule()->getGlobal(get->name); - assert(global->init->is<Const>()); + assert(Properties::isConstantExpression(global->init)); replaceCurrent(ExpressionManipulator::copy(global->init, *getModule())); replaced = true; return; @@ -125,7 +127,7 @@ struct ConstantGlobalApplier auto iter = currConstantGlobals.find(get->name); if (iter != currConstantGlobals.end()) { Builder builder(*getModule()); - replaceCurrent(builder.makeConst(iter->second)); + replaceCurrent(builder.makeConstExpression(iter->second)); replaced = true; } return; @@ -249,13 +251,14 @@ struct SimplifyGlobals : public Pass { std::map<Name, Literal> constantGlobals; for (auto& global : module->globals) { if (!global->imported()) { - if (auto* c = global->init->dynCast<Const>()) { - constantGlobals[global->name] = c->value; + if (Properties::isConstantExpression(global->init)) { + constantGlobals[global->name] = + getLiteralFromConstExpression(global->init); } else if (auto* get = global->init->dynCast<GlobalGet>()) { auto iter = constantGlobals.find(get->name); if (iter != constantGlobals.end()) { Builder builder(*module); - global->init = builder.makeConst(iter->second); + global->init = builder.makeConstExpression(iter->second); } } } @@ -268,7 +271,7 @@ struct SimplifyGlobals : public Pass { NameSet constantGlobals; for (auto& global : module->globals) { if (!global->mutable_ && !global->imported() && - global->init->is<Const>()) { + Properties::isConstantExpression(global->init)) { constantGlobals.insert(global->name); } } diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index a3fa4a34d..a952f8a38 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -546,7 +546,6 @@ struct SimplifyLocals auto* blockLocalSetPointer = sinkables.at(sharedIndex).item; auto* value = (*blockLocalSetPointer)->template cast<LocalSet>()->value; block->list[block->list.size() - 1] = value; - block->type = value->type; ExpressionManipulator::nop(*blockLocalSetPointer); for (size_t j = 0; j < breaks.size(); j++) { // move break local.set's value to the break @@ -577,6 +576,7 @@ struct SimplifyLocals this->replaceCurrent(newLocalSet); sinkables.clear(); anotherCycle = true; + block->finalize(); } // optimize local.sets from both sides of an if into a return value @@ -915,6 +915,7 @@ struct SimplifyLocals void visitLocalSet(LocalSet* curr) { // Remove trivial copies, even through a tee auto* value = curr->value; + Function* func = this->getFunction(); while (auto* subSet = value->dynCast<LocalSet>()) { value = subSet->value; } @@ -929,7 +930,8 @@ struct SimplifyLocals } anotherCycle = true; } - } else { + } else if (func->getLocalType(curr->index) == + func->getLocalType(get->index)) { // There is a new equivalence now. equivalences.reset(curr->index); equivalences.add(curr->index, get->index); diff --git a/src/passes/opt-utils.h b/src/passes/opt-utils.h index 93fac137f..7912a7d92 100644 --- a/src/passes/opt-utils.h +++ b/src/passes/opt-utils.h @@ -54,19 +54,22 @@ inline void optimizeAfterInlining(std::unordered_set<Function*>& funcs, module->updateMaps(); } -struct CallTargetReplacer : public WalkerPass<PostWalker<CallTargetReplacer>> { +struct FunctionRefReplacer + : public WalkerPass<PostWalker<FunctionRefReplacer>> { bool isFunctionParallel() override { return true; } using MaybeReplace = std::function<void(Name&)>; - CallTargetReplacer(MaybeReplace maybeReplace) : maybeReplace(maybeReplace) {} + FunctionRefReplacer(MaybeReplace maybeReplace) : maybeReplace(maybeReplace) {} - CallTargetReplacer* create() override { - return new CallTargetReplacer(maybeReplace); + FunctionRefReplacer* create() override { + return new FunctionRefReplacer(maybeReplace); } void visitCall(Call* curr) { maybeReplace(curr->target); } + void visitRefFunc(RefFunc* curr) { maybeReplace(curr->func); } + private: MaybeReplace maybeReplace; }; @@ -81,7 +84,7 @@ inline void replaceFunctions(PassRunner* runner, } }; // replace direct calls - CallTargetReplacer(maybeReplace).run(runner, &module); + FunctionRefReplacer(maybeReplace).run(runner, &module); // replace in table for (auto& segment : module.table.segments) { for (auto& name : segment.data) { |