diff options
Diffstat (limited to 'src/passes/DeNaN.cpp')
-rw-r--r-- | src/passes/DeNaN.cpp | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/src/passes/DeNaN.cpp b/src/passes/DeNaN.cpp index 044c13c86..c894cb3f1 100644 --- a/src/passes/DeNaN.cpp +++ b/src/passes/DeNaN.cpp @@ -22,6 +22,7 @@ // differ on wasm's nondeterminism around NaNs. // +#include "ir/properties.h" #include "pass.h" #include "wasm-builder.h" #include "wasm.h" @@ -33,7 +34,18 @@ struct DeNaN : public WalkerPass< void visitExpression(Expression* expr) { // If the expression returns a floating-point value, ensure it is not a // NaN. If we can do this at compile time, do it now, which is useful for - // initializations of global (which we can't do a function call in). + // initializations of global (which we can't do a function call in). Note + // that we don't instrument local.gets, which would cause problems if we + // ran this pass more than once (the added functions use gets, and we don't + // want to instrument them). + if (expr->is<LocalGet>()) { + return; + } + // If the result just falls through without being modified, then we've + // already fixed it up earlier. + if (Properties::isResultFallthrough(expr)) { + return; + } Builder builder(*getModule()); Expression* replacement = nullptr; auto* c = expr->dynCast<Const>(); @@ -61,6 +73,38 @@ struct DeNaN : public WalkerPass< } } + void visitFunction(Function* func) { + if (func->imported()) { + return; + } + // Instrument all locals as they enter the function. + Builder builder(*getModule()); + std::vector<Expression*> fixes; + auto num = func->getNumParams(); + for (Index i = 0; i < num; i++) { + if (func->getLocalType(i) == Type::f32) { + fixes.push_back(builder.makeLocalSet( + i, + builder.makeCall( + "deNan32", {builder.makeLocalGet(i, Type::f32)}, Type::f32))); + } else if (func->getLocalType(i) == Type::f64) { + fixes.push_back(builder.makeLocalSet( + i, + builder.makeCall( + "deNan64", {builder.makeLocalGet(i, Type::f64)}, Type::f64))); + } + } + if (!fixes.empty()) { + fixes.push_back(func->body); + func->body = builder.makeBlock(fixes); + // Merge blocks so we don't add an unnecessary one. + PassRunner runner(getModule(), getPassOptions()); + runner.setIsNested(true); + runner.add("merge-blocks"); + runner.run(); + } + } + void visitModule(Module* module) { // Add helper functions. Builder builder(*module); |