summaryrefslogtreecommitdiff
path: root/src/passes
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes')
-rw-r--r--src/passes/ConstHoisting.cpp9
-rw-r--r--src/passes/DeadCodeElimination.cpp6
-rw-r--r--src/passes/Flatten.cpp37
-rw-r--r--src/passes/FuncCastEmulation.cpp16
-rw-r--r--src/passes/Inlining.cpp17
-rw-r--r--src/passes/InstrumentLocals.cpp32
-rw-r--r--src/passes/LegalizeJSInterface.cpp33
-rw-r--r--src/passes/LocalCSE.cpp7
-rw-r--r--src/passes/MergeLocals.cpp15
-rw-r--r--src/passes/OptimizeInstructions.cpp6
-rw-r--r--src/passes/Precompute.cpp19
-rw-r--r--src/passes/Print.cpp40
-rw-r--r--src/passes/RemoveUnusedModuleElements.cpp6
-rw-r--r--src/passes/SimplifyGlobals.cpp19
-rw-r--r--src/passes/SimplifyLocals.cpp6
-rw-r--r--src/passes/opt-utils.h13
16 files changed, 227 insertions, 54 deletions
diff --git a/src/passes/ConstHoisting.cpp b/src/passes/ConstHoisting.cpp
index dbb3853d8..4e8cd9910 100644
--- a/src/passes/ConstHoisting.cpp
+++ b/src/passes/ConstHoisting.cpp
@@ -91,9 +91,12 @@ private:
size = value.type.getByteSize();
break;
}
- case v128: // v128 not implemented yet
- case anyref: // anyref cannot have literals
- case exnref: { // exnref cannot have literals
+ // not implemented yet
+ case v128:
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref: {
return false;
}
case none:
diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp
index be6f92ffa..7d5385a83 100644
--- a/src/passes/DeadCodeElimination.cpp
+++ b/src/passes/DeadCodeElimination.cpp
@@ -347,6 +347,12 @@ struct DeadCodeElimination
DELEGATE(Push);
case Expression::Id::PopId:
DELEGATE(Pop);
+ case Expression::Id::RefNullId:
+ DELEGATE(RefNull);
+ case Expression::Id::RefIsNullId:
+ DELEGATE(RefIsNull);
+ case Expression::Id::RefFuncId:
+ DELEGATE(RefFunc);
case Expression::Id::TryId:
DELEGATE(Try);
case Expression::Id::ThrowId:
diff --git a/src/passes/Flatten.cpp b/src/passes/Flatten.cpp
index f5115567b..fda8e3f80 100644
--- a/src/passes/Flatten.cpp
+++ b/src/passes/Flatten.cpp
@@ -21,6 +21,7 @@
#include <ir/branch-utils.h>
#include <ir/effects.h>
#include <ir/flat.h>
+#include <ir/properties.h>
#include <ir/utils.h>
#include <pass.h>
#include <wasm-builder.h>
@@ -61,7 +62,9 @@ struct Flatten
std::vector<Expression*> ourPreludes;
Builder builder(*getModule());
- if (curr->is<Const>() || curr->is<Nop>() || curr->is<Unreachable>()) {
+ // Nothing to do for constants, nop, and unreachable
+ if (Properties::isConstantExpression(curr) || curr->is<Nop>() ||
+ curr->is<Unreachable>()) {
return;
}
@@ -194,8 +197,37 @@ struct Flatten
auto type = br->value->type;
if (type.isConcrete()) {
// we are sending a value. use a local instead
- Index temp = getTempForBreakTarget(br->name, type);
+ Type blockType = findBreakTarget(br->name)->type;
+ Index temp = getTempForBreakTarget(br->name, blockType);
ourPreludes.push_back(builder.makeLocalSet(temp, br->value));
+
+ // br_if leaves a value on the stack if not taken, which later can
+ // be the last element of the enclosing innermost block and flow
+ // out. The local we created using 'getTempForBreakTarget' returns
+ // the return type of the block this branch is targetting, which may
+ // not be the same with the innermost block's return type. For
+ // example,
+ // (block $any (result anyref)
+ // (block (result nullref)
+ // (local.tee $0
+ // (br_if $any
+ // (ref.null)
+ // (i32.const 0)
+ // )
+ // )
+ // )
+ // )
+ // In this case we need two locals to store (ref.null); one with
+ // anyref type that's for the target block ($label0) and one more
+ // with nullref type in case for flowing out. Here we create the
+ // second 'flowing out' local in case two block's types are
+ // different.
+ if (type != blockType) {
+ temp = builder.addVar(getFunction(), type);
+ ourPreludes.push_back(builder.makeLocalSet(
+ temp, ExpressionManipulator::copy(br->value, *getModule())));
+ }
+
if (br->condition) {
// the value must also flow out
ourPreludes.push_back(br);
@@ -239,6 +271,7 @@ struct Flatten
}
}
}
+ // TODO Handle br_on_exn
// continue for general handling of everything, control flow or otherwise
curr = getCurrent(); // we may have replaced it
diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp
index 729a4a6c3..9d5109a83 100644
--- a/src/passes/FuncCastEmulation.cpp
+++ b/src/passes/FuncCastEmulation.cpp
@@ -65,11 +65,11 @@ static Expression* toABI(Expression* value, Module* module) {
case v128: {
WASM_UNREACHABLE("v128 not implemented yet");
}
- case anyref: {
- WASM_UNREACHABLE("anyref cannot be converted to i64");
- }
+ case funcref:
+ case anyref:
+ case nullref:
case exnref: {
- WASM_UNREACHABLE("exnref cannot be converted to i64");
+ WASM_UNREACHABLE("reference types cannot be converted to i64");
}
case none: {
// the value is none, but we need a value here
@@ -108,11 +108,11 @@ static Expression* fromABI(Expression* value, Type type, Module* module) {
case v128: {
WASM_UNREACHABLE("v128 not implemented yet");
}
- case anyref: {
- WASM_UNREACHABLE("anyref cannot be converted from i64");
- }
+ case funcref:
+ case anyref:
+ case nullref:
case exnref: {
- WASM_UNREACHABLE("exnref cannot be converted from i64");
+ WASM_UNREACHABLE("reference types cannot be converted from i64");
}
case none: {
value = builder.makeDrop(value);
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp
index db1db5971..c43d41e7f 100644
--- a/src/passes/Inlining.cpp
+++ b/src/passes/Inlining.cpp
@@ -46,13 +46,13 @@ namespace wasm {
// Useful into on a function, helping us decide if we can inline it
struct FunctionInfo {
- std::atomic<Index> calls;
+ std::atomic<Index> refs;
Index size;
std::atomic<bool> lightweight;
bool usedGlobally; // in a table or export
FunctionInfo() {
- calls = 0;
+ refs = 0;
size = 0;
lightweight = true;
usedGlobally = false;
@@ -75,7 +75,7 @@ struct FunctionInfo {
// FIXME: move this check to be first in this function, since we should
// return true if oneCallerInlineMaxSize is bigger than
// flexibleInlineMaxSize (which it typically should be).
- if (calls == 1 && !usedGlobally &&
+ if (refs == 1 && !usedGlobally &&
size <= options.inlining.oneCallerInlineMaxSize) {
return true;
}
@@ -108,11 +108,16 @@ struct FunctionInfoScanner
void visitCall(Call* curr) {
// can't add a new element in parallel
assert(infos->count(curr->target) > 0);
- (*infos)[curr->target].calls++;
+ (*infos)[curr->target].refs++;
// having a call is not lightweight
(*infos)[getFunction()->name].lightweight = false;
}
+ void visitRefFunc(RefFunc* curr) {
+ assert(infos->count(curr->func) > 0);
+ (*infos)[curr->func].refs++;
+ }
+
void visitFunction(Function* curr) {
(*infos)[curr->name].size = Measurer::measure(curr->body);
}
@@ -374,7 +379,7 @@ struct Inlining : public Pass {
doInlining(module, func.get(), action);
inlinedUses[inlinedName]++;
inlinedInto.insert(func.get());
- assert(inlinedUses[inlinedName] <= infos[inlinedName].calls);
+ assert(inlinedUses[inlinedName] <= infos[inlinedName].refs);
}
}
// anything we inlined into may now have non-unique label names, fix it up
@@ -388,7 +393,7 @@ struct Inlining : public Pass {
module->removeFunctions([&](Function* func) {
auto name = func->name;
auto& info = infos[name];
- return inlinedUses.count(name) && inlinedUses[name] == info.calls &&
+ return inlinedUses.count(name) && inlinedUses[name] == info.refs &&
!info.usedGlobally;
});
// return whether we did any work
diff --git a/src/passes/InstrumentLocals.cpp b/src/passes/InstrumentLocals.cpp
index 407903219..ae35ec2d1 100644
--- a/src/passes/InstrumentLocals.cpp
+++ b/src/passes/InstrumentLocals.cpp
@@ -56,14 +56,18 @@ Name get_i32("get_i32");
Name get_i64("get_i64");
Name get_f32("get_f32");
Name get_f64("get_f64");
+Name get_funcref("get_funcref");
Name get_anyref("get_anyref");
+Name get_nullref("get_nullref");
Name get_exnref("get_exnref");
Name set_i32("set_i32");
Name set_i64("set_i64");
Name set_f32("set_f32");
Name set_f64("set_f64");
+Name set_funcref("set_funcref");
Name set_anyref("set_anyref");
+Name set_nullref("set_nullref");
Name set_exnref("set_exnref");
struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
@@ -84,9 +88,15 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
break;
case v128:
assert(false && "v128 not implemented yet");
+ case funcref:
+ import = get_funcref;
+ break;
case anyref:
import = get_anyref;
break;
+ case nullref:
+ import = get_nullref;
+ break;
case exnref:
import = get_exnref;
break;
@@ -126,9 +136,15 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
break;
case v128:
assert(false && "v128 not implemented yet");
+ case funcref:
+ import = set_funcref;
+ break;
case anyref:
import = set_anyref;
break;
+ case nullref:
+ import = set_nullref;
+ break;
case exnref:
import = set_exnref;
break;
@@ -156,10 +172,26 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
addImport(curr, set_f64, {Type::i32, Type::i32, Type::f64}, Type::f64);
if (curr->features.hasReferenceTypes()) {
+ addImport(curr,
+ get_funcref,
+ {Type::i32, Type::i32, Type::funcref},
+ Type::funcref);
+ addImport(curr,
+ set_funcref,
+ {Type::i32, Type::i32, Type::funcref},
+ Type::funcref);
addImport(
curr, get_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref);
addImport(
curr, set_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref);
+ addImport(curr,
+ get_nullref,
+ {Type::i32, Type::i32, Type::nullref},
+ Type::nullref);
+ addImport(curr,
+ set_nullref,
+ {Type::i32, Type::i32, Type::nullref},
+ Type::nullref);
}
if (curr->features.hasExceptionHandling()) {
addImport(
diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp
index 8c7bc4414..df6651b0d 100644
--- a/src/passes/LegalizeJSInterface.cpp
+++ b/src/passes/LegalizeJSInterface.cpp
@@ -107,14 +107,43 @@ struct LegalizeJSInterface : public Pass {
}
}
}
+
if (!illegalImportsToLegal.empty()) {
+ // Gather functions used in 'ref.func'. They should not be removed.
+ std::unordered_map<Name, std::atomic<bool>> usedInRefFunc;
+
+ struct RefFuncScanner : public WalkerPass<PostWalker<RefFuncScanner>> {
+ Module& wasm;
+ std::unordered_map<Name, std::atomic<bool>>& usedInRefFunc;
+
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override {
+ return new RefFuncScanner(wasm, usedInRefFunc);
+ }
+
+ RefFuncScanner(
+ Module& wasm,
+ std::unordered_map<Name, std::atomic<bool>>& usedInRefFunc)
+ : wasm(wasm), usedInRefFunc(usedInRefFunc) {
+ // Fill in unordered_map, as we operate on it in parallel
+ for (auto& func : wasm.functions) {
+ usedInRefFunc[func->name];
+ }
+ }
+
+ void visitRefFunc(RefFunc* curr) { usedInRefFunc[curr->func] = true; }
+ };
+
+ RefFuncScanner(*module, usedInRefFunc).run(runner, module);
for (auto& pair : illegalImportsToLegal) {
- module->removeFunction(pair.first);
+ if (!usedInRefFunc[pair.first]) {
+ module->removeFunction(pair.first);
+ }
}
// fix up imports: call_import of an illegal must be turned to a call of a
// legal
-
struct FixImports : public WalkerPass<PostWalker<FixImports>> {
bool isFunctionParallel() override { return true; }
diff --git a/src/passes/LocalCSE.cpp b/src/passes/LocalCSE.cpp
index 0816bf6ea..b49c92310 100644
--- a/src/passes/LocalCSE.cpp
+++ b/src/passes/LocalCSE.cpp
@@ -172,9 +172,12 @@ struct LocalCSE : public WalkerPass<LinearExecutionWalker<LocalCSE>> {
void handle(Expression* curr) {
if (auto* set = curr->dynCast<LocalSet>()) {
// Calculate equivalences
+ auto* func = getFunction();
equivalences.reset(set->index);
if (auto* get = set->value->dynCast<LocalGet>()) {
- equivalences.add(set->index, get->index);
+ if (func->getLocalType(set->index) == func->getLocalType(get->index)) {
+ equivalences.add(set->index, get->index);
+ }
}
// consider the value
auto* value = set->value;
@@ -184,7 +187,7 @@ struct LocalCSE : public WalkerPass<LinearExecutionWalker<LocalCSE>> {
if (iter != usables.end()) {
// already exists in the table, this is good to reuse
auto& info = iter->second;
- Type localType = getFunction()->getLocalType(info.index);
+ Type localType = func->getLocalType(info.index);
set->value =
Builder(*getModule()).makeLocalGet(info.index, localType);
anotherPass = true;
diff --git a/src/passes/MergeLocals.cpp b/src/passes/MergeLocals.cpp
index 0116753f1..2223594b6 100644
--- a/src/passes/MergeLocals.cpp
+++ b/src/passes/MergeLocals.cpp
@@ -100,7 +100,8 @@ struct MergeLocals
return;
}
// compute all dependencies
- LocalGraph preGraph(getFunction());
+ auto* func = getFunction();
+ LocalGraph preGraph(func);
preGraph.computeInfluences();
// optimize each copy
std::unordered_map<LocalSet*, LocalSet*> optimizedToCopy,
@@ -119,6 +120,11 @@ struct MergeLocals
if (preGraph.getSetses[influencedGet].size() == 1) {
// this is ok
assert(*preGraph.getSetses[influencedGet].begin() == trivial);
+ // If local types are different (when one is a subtype of the
+ // other), don't optimize
+ if (func->getLocalType(copy->index) != influencedGet->type) {
+ canOptimizeToCopy = false;
+ }
} else {
canOptimizeToCopy = false;
break;
@@ -152,6 +158,11 @@ struct MergeLocals
if (preGraph.getSetses[influencedGet].size() == 1) {
// this is ok
assert(*preGraph.getSetses[influencedGet].begin() == copy);
+ // If local types are different (when one is a subtype of the
+ // other), don't optimize
+ if (func->getLocalType(trivial->index) != influencedGet->type) {
+ canOptimizeToTrivial = false;
+ }
} else {
canOptimizeToTrivial = false;
break;
@@ -176,7 +187,7 @@ struct MergeLocals
// if one does not work, we need to undo all its siblings (don't extend
// the live range unless we are definitely removing a conflict, same
// logic as before).
- LocalGraph postGraph(getFunction());
+ LocalGraph postGraph(func);
postGraph.computeInfluences();
for (auto& pair : optimizedToCopy) {
auto* copy = pair.first;
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 6de1d3d00..edd6ba2b6 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -751,12 +751,12 @@ struct OptimizeInstructions
// condition, do that
auto needCondition =
EffectAnalyzer(getPassOptions(), iff->condition).hasSideEffects();
- auto typeIsIdentical = iff->ifTrue->type == iff->type;
- if (typeIsIdentical && !needCondition) {
+ auto isSubType = Type::isSubType(iff->ifTrue->type, iff->type);
+ if (isSubType && !needCondition) {
return iff->ifTrue;
} else {
Builder builder(*getModule());
- if (typeIsIdentical) {
+ if (isSubType) {
return builder.makeSequence(builder.makeDrop(iff->condition),
iff->ifTrue);
} else {
diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp
index 57a3ab27f..85eb026f9 100644
--- a/src/passes/Precompute.cpp
+++ b/src/passes/Precompute.cpp
@@ -177,7 +177,7 @@ struct Precompute
void visitExpression(Expression* curr) {
// TODO: if local.get, only replace with a constant if we don't care about
// size...?
- if (curr->is<Const>() || curr->is<Nop>()) {
+ if (Properties::isConstantExpression(curr) || curr->is<Nop>()) {
return;
}
// Until engines implement v128.const and we have SIMD-aware optimizations
@@ -208,14 +208,16 @@ struct Precompute
return;
}
}
- ret->value = Builder(*getModule()).makeConst(flow.value);
+ ret->value = Builder(*getModule()).makeConstExpression(flow.value);
} else {
ret->value = nullptr;
}
} else {
Builder builder(*getModule());
- replaceCurrent(builder.makeReturn(
- flow.value.type != none ? builder.makeConst(flow.value) : nullptr));
+ replaceCurrent(
+ builder.makeReturn(flow.value.type != Type::none
+ ? builder.makeConstExpression(flow.value)
+ : nullptr));
}
return;
}
@@ -234,7 +236,7 @@ struct Precompute
return;
}
}
- br->value = Builder(*getModule()).makeConst(flow.value);
+ br->value = Builder(*getModule()).makeConstExpression(flow.value);
} else {
br->value = nullptr;
}
@@ -243,13 +245,14 @@ struct Precompute
Builder builder(*getModule());
replaceCurrent(builder.makeBreak(
flow.breakTo,
- flow.value.type != none ? builder.makeConst(flow.value) : nullptr));
+ flow.value.type != none ? builder.makeConstExpression(flow.value)
+ : nullptr));
}
return;
}
// this was precomputed
if (flow.value.type.isConcrete()) {
- replaceCurrent(Builder(*getModule()).makeConst(flow.value));
+ replaceCurrent(Builder(*getModule()).makeConstExpression(flow.value));
worked = true;
} else {
ExpressionManipulator::nop(curr);
@@ -350,7 +353,7 @@ private:
} else {
curr = setValues[set];
}
- if (curr.isNull()) {
+ if (curr.isNone()) {
// not a constant, give up
value = Literal();
break;
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 5efd1fd28..51e78c8a7 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -1333,7 +1333,12 @@ struct PrintExpressionContents
}
restoreNormalColor(o);
}
- void visitSelect(Select* curr) { prepareColor(o) << "select"; }
+ void visitSelect(Select* curr) {
+ prepareColor(o) << "select";
+ if (curr->type.isRef()) {
+ o << " (result " << curr->type << ')';
+ }
+ }
void visitDrop(Drop* curr) { printMedium(o, "drop"); }
void visitReturn(Return* curr) { printMedium(o, "return"); }
void visitHost(Host* curr) {
@@ -1346,6 +1351,12 @@ struct PrintExpressionContents
break;
}
}
+ void visitRefNull(RefNull* curr) { printMedium(o, "ref.null"); }
+ void visitRefIsNull(RefIsNull* curr) { printMedium(o, "ref.is_null"); }
+ void visitRefFunc(RefFunc* curr) {
+ printMedium(o, "ref.func ");
+ printName(curr->func, o);
+ }
void visitTry(Try* curr) {
printMedium(o, "try");
if (curr->type.isConcrete()) {
@@ -1852,6 +1863,23 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
}
}
}
+ void visitRefNull(RefNull* curr) {
+ o << '(';
+ PrintExpressionContents(currFunction, o).visit(curr);
+ o << ')';
+ }
+ void visitRefIsNull(RefIsNull* curr) {
+ o << '(';
+ PrintExpressionContents(currFunction, o).visit(curr);
+ incIndent();
+ printFullLine(curr->value);
+ decIndent();
+ }
+ void visitRefFunc(RefFunc* curr) {
+ o << '(';
+ PrintExpressionContents(currFunction, o).visit(curr);
+ o << ')';
+ }
// try-catch-end is written in the folded wat format as
// (try
// ...
@@ -2434,13 +2462,15 @@ WasmPrinter::printStackInst(StackInst* inst, std::ostream& o, Function* func) {
}
case StackInst::BlockBegin:
case StackInst::IfBegin:
- case StackInst::LoopBegin: {
+ case StackInst::LoopBegin:
+ case StackInst::TryBegin: {
o << getExpressionName(inst->origin);
break;
}
case StackInst::BlockEnd:
case StackInst::IfEnd:
- case StackInst::LoopEnd: {
+ case StackInst::LoopEnd:
+ case StackInst::TryEnd: {
o << "end (" << inst->type << ')';
break;
}
@@ -2448,6 +2478,10 @@ WasmPrinter::printStackInst(StackInst* inst, std::ostream& o, Function* func) {
o << "else";
break;
}
+ case StackInst::Catch: {
+ o << "catch";
+ break;
+ }
default:
WASM_UNREACHABLE("unexpeted op");
}
diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp
index f5000e3a4..21cbc5e5b 100644
--- a/src/passes/RemoveUnusedModuleElements.cpp
+++ b/src/passes/RemoveUnusedModuleElements.cpp
@@ -116,6 +116,12 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> {
usesMemory = true;
}
}
+ void visitRefFunc(RefFunc* curr) {
+ if (reachable.count(
+ ModuleElement(ModuleElementKind::Function, curr->func)) == 0) {
+ queue.emplace_back(ModuleElementKind::Function, curr->func);
+ }
+ }
void visitThrow(Throw* curr) {
if (reachable.count(ModuleElement(ModuleElementKind::Event, curr->event)) ==
0) {
diff --git a/src/passes/SimplifyGlobals.cpp b/src/passes/SimplifyGlobals.cpp
index 88f27f8be..b18f726ed 100644
--- a/src/passes/SimplifyGlobals.cpp
+++ b/src/passes/SimplifyGlobals.cpp
@@ -37,6 +37,7 @@
#include <atomic>
#include "ir/effects.h"
+#include "ir/properties.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-builder.h"
@@ -106,8 +107,9 @@ struct ConstantGlobalApplier
void visitExpression(Expression* curr) {
if (auto* set = curr->dynCast<GlobalSet>()) {
- if (auto* c = set->value->dynCast<Const>()) {
- currConstantGlobals[set->name] = c->value;
+ if (Properties::isConstantExpression(set->value)) {
+ currConstantGlobals[set->name] =
+ getLiteralFromConstExpression(set->value);
} else {
currConstantGlobals.erase(set->name);
}
@@ -116,7 +118,7 @@ struct ConstantGlobalApplier
// Check if the global is known to be constant all the time.
if (constantGlobals->count(get->name)) {
auto* global = getModule()->getGlobal(get->name);
- assert(global->init->is<Const>());
+ assert(Properties::isConstantExpression(global->init));
replaceCurrent(ExpressionManipulator::copy(global->init, *getModule()));
replaced = true;
return;
@@ -125,7 +127,7 @@ struct ConstantGlobalApplier
auto iter = currConstantGlobals.find(get->name);
if (iter != currConstantGlobals.end()) {
Builder builder(*getModule());
- replaceCurrent(builder.makeConst(iter->second));
+ replaceCurrent(builder.makeConstExpression(iter->second));
replaced = true;
}
return;
@@ -249,13 +251,14 @@ struct SimplifyGlobals : public Pass {
std::map<Name, Literal> constantGlobals;
for (auto& global : module->globals) {
if (!global->imported()) {
- if (auto* c = global->init->dynCast<Const>()) {
- constantGlobals[global->name] = c->value;
+ if (Properties::isConstantExpression(global->init)) {
+ constantGlobals[global->name] =
+ getLiteralFromConstExpression(global->init);
} else if (auto* get = global->init->dynCast<GlobalGet>()) {
auto iter = constantGlobals.find(get->name);
if (iter != constantGlobals.end()) {
Builder builder(*module);
- global->init = builder.makeConst(iter->second);
+ global->init = builder.makeConstExpression(iter->second);
}
}
}
@@ -268,7 +271,7 @@ struct SimplifyGlobals : public Pass {
NameSet constantGlobals;
for (auto& global : module->globals) {
if (!global->mutable_ && !global->imported() &&
- global->init->is<Const>()) {
+ Properties::isConstantExpression(global->init)) {
constantGlobals.insert(global->name);
}
}
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index a3fa4a34d..a952f8a38 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -546,7 +546,6 @@ struct SimplifyLocals
auto* blockLocalSetPointer = sinkables.at(sharedIndex).item;
auto* value = (*blockLocalSetPointer)->template cast<LocalSet>()->value;
block->list[block->list.size() - 1] = value;
- block->type = value->type;
ExpressionManipulator::nop(*blockLocalSetPointer);
for (size_t j = 0; j < breaks.size(); j++) {
// move break local.set's value to the break
@@ -577,6 +576,7 @@ struct SimplifyLocals
this->replaceCurrent(newLocalSet);
sinkables.clear();
anotherCycle = true;
+ block->finalize();
}
// optimize local.sets from both sides of an if into a return value
@@ -915,6 +915,7 @@ struct SimplifyLocals
void visitLocalSet(LocalSet* curr) {
// Remove trivial copies, even through a tee
auto* value = curr->value;
+ Function* func = this->getFunction();
while (auto* subSet = value->dynCast<LocalSet>()) {
value = subSet->value;
}
@@ -929,7 +930,8 @@ struct SimplifyLocals
}
anotherCycle = true;
}
- } else {
+ } else if (func->getLocalType(curr->index) ==
+ func->getLocalType(get->index)) {
// There is a new equivalence now.
equivalences.reset(curr->index);
equivalences.add(curr->index, get->index);
diff --git a/src/passes/opt-utils.h b/src/passes/opt-utils.h
index 93fac137f..7912a7d92 100644
--- a/src/passes/opt-utils.h
+++ b/src/passes/opt-utils.h
@@ -54,19 +54,22 @@ inline void optimizeAfterInlining(std::unordered_set<Function*>& funcs,
module->updateMaps();
}
-struct CallTargetReplacer : public WalkerPass<PostWalker<CallTargetReplacer>> {
+struct FunctionRefReplacer
+ : public WalkerPass<PostWalker<FunctionRefReplacer>> {
bool isFunctionParallel() override { return true; }
using MaybeReplace = std::function<void(Name&)>;
- CallTargetReplacer(MaybeReplace maybeReplace) : maybeReplace(maybeReplace) {}
+ FunctionRefReplacer(MaybeReplace maybeReplace) : maybeReplace(maybeReplace) {}
- CallTargetReplacer* create() override {
- return new CallTargetReplacer(maybeReplace);
+ FunctionRefReplacer* create() override {
+ return new FunctionRefReplacer(maybeReplace);
}
void visitCall(Call* curr) { maybeReplace(curr->target); }
+ void visitRefFunc(RefFunc* curr) { maybeReplace(curr->func); }
+
private:
MaybeReplace maybeReplace;
};
@@ -81,7 +84,7 @@ inline void replaceFunctions(PassRunner* runner,
}
};
// replace direct calls
- CallTargetReplacer(maybeReplace).run(runner, &module);
+ FunctionRefReplacer(maybeReplace).run(runner, &module);
// replace in table
for (auto& segment : module.table.segments) {
for (auto& name : segment.data) {