diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/properties.h | 23 | ||||
-rw-r--r-- | src/passes/DeNaN.cpp | 46 |
2 files changed, 67 insertions, 2 deletions
diff --git a/src/ir/properties.h b/src/ir/properties.h index 094d90bd6..0c6824e4a 100644 --- a/src/ir/properties.h +++ b/src/ir/properties.h @@ -200,7 +200,9 @@ inline Index getZeroExtBits(Expression* curr) { } // Returns a falling-through value, that is, it looks through a local.tee -// and other operations that receive a value and let it flow through them. +// and other operations that receive a value and let it flow through them. If +// there is no value falling through, returns the node itself (as that is the +// value that trivially falls through, with 0 steps in the middle). inline Expression* getFallthrough(Expression* curr, const PassOptions& passOptions, FeatureSet features) { @@ -241,6 +243,25 @@ inline Expression* getFallthrough(Expression* curr, return curr; } +// Returns whether the resulting value here must fall through without being +// modified. For example, a tee always does so. That is, this returns false if +// and only if the return value may have some computation performed on it to +// change it from the inputs the instruction receives. +// This differs from getFallthrough() which returns a single value that falls +// through - here if more than one value can fall through, like in if-else, +// we can return true. That is, there we care about a value falling through and +// for us to get that actual value to look at; here we just care whether the +// value falls through without being changed, even if it might be one of +// several options. +inline bool isResultFallthrough(Expression* curr) { + // Note that we don't check if there is a return value here; the node may be + // unreachable, for example, but then there is no meaningful answer to give + // anyhow. + return curr->is<LocalSet>() || curr->is<Block>() || curr->is<If>() || + curr->is<Loop>() || curr->is<Try>() || curr->is<Select>() || + curr->is<Break>(); +} + } // namespace Properties } // namespace wasm diff --git a/src/passes/DeNaN.cpp b/src/passes/DeNaN.cpp index 044c13c86..c894cb3f1 100644 --- a/src/passes/DeNaN.cpp +++ b/src/passes/DeNaN.cpp @@ -22,6 +22,7 @@ // differ on wasm's nondeterminism around NaNs. // +#include "ir/properties.h" #include "pass.h" #include "wasm-builder.h" #include "wasm.h" @@ -33,7 +34,18 @@ struct DeNaN : public WalkerPass< void visitExpression(Expression* expr) { // If the expression returns a floating-point value, ensure it is not a // NaN. If we can do this at compile time, do it now, which is useful for - // initializations of global (which we can't do a function call in). + // initializations of global (which we can't do a function call in). Note + // that we don't instrument local.gets, which would cause problems if we + // ran this pass more than once (the added functions use gets, and we don't + // want to instrument them). + if (expr->is<LocalGet>()) { + return; + } + // If the result just falls through without being modified, then we've + // already fixed it up earlier. + if (Properties::isResultFallthrough(expr)) { + return; + } Builder builder(*getModule()); Expression* replacement = nullptr; auto* c = expr->dynCast<Const>(); @@ -61,6 +73,38 @@ struct DeNaN : public WalkerPass< } } + void visitFunction(Function* func) { + if (func->imported()) { + return; + } + // Instrument all locals as they enter the function. + Builder builder(*getModule()); + std::vector<Expression*> fixes; + auto num = func->getNumParams(); + for (Index i = 0; i < num; i++) { + if (func->getLocalType(i) == Type::f32) { + fixes.push_back(builder.makeLocalSet( + i, + builder.makeCall( + "deNan32", {builder.makeLocalGet(i, Type::f32)}, Type::f32))); + } else if (func->getLocalType(i) == Type::f64) { + fixes.push_back(builder.makeLocalSet( + i, + builder.makeCall( + "deNan64", {builder.makeLocalGet(i, Type::f64)}, Type::f64))); + } + } + if (!fixes.empty()) { + fixes.push_back(func->body); + func->body = builder.makeBlock(fixes); + // Merge blocks so we don't add an unnecessary one. + PassRunner runner(getModule(), getPassOptions()); + runner.setIsNested(true); + runner.add("merge-blocks"); + runner.run(); + } + } + void visitModule(Module* module) { // Add helper functions. Builder builder(*module); |