summaryrefslogtreecommitdiff
path: root/src/passes/DeNaN.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/DeNaN.cpp')
-rw-r--r--src/passes/DeNaN.cpp46
1 files changed, 45 insertions, 1 deletions
diff --git a/src/passes/DeNaN.cpp b/src/passes/DeNaN.cpp
index 044c13c86..c894cb3f1 100644
--- a/src/passes/DeNaN.cpp
+++ b/src/passes/DeNaN.cpp
@@ -22,6 +22,7 @@
// differ on wasm's nondeterminism around NaNs.
//
+#include "ir/properties.h"
#include "pass.h"
#include "wasm-builder.h"
#include "wasm.h"
@@ -33,7 +34,18 @@ struct DeNaN : public WalkerPass<
void visitExpression(Expression* expr) {
// If the expression returns a floating-point value, ensure it is not a
// NaN. If we can do this at compile time, do it now, which is useful for
- // initializations of global (which we can't do a function call in).
+ // initializations of global (which we can't do a function call in). Note
+ // that we don't instrument local.gets, which would cause problems if we
+ // ran this pass more than once (the added functions use gets, and we don't
+ // want to instrument them).
+ if (expr->is<LocalGet>()) {
+ return;
+ }
+ // If the result just falls through without being modified, then we've
+ // already fixed it up earlier.
+ if (Properties::isResultFallthrough(expr)) {
+ return;
+ }
Builder builder(*getModule());
Expression* replacement = nullptr;
auto* c = expr->dynCast<Const>();
@@ -61,6 +73,38 @@ struct DeNaN : public WalkerPass<
}
}
+ void visitFunction(Function* func) {
+ if (func->imported()) {
+ return;
+ }
+ // Instrument all locals as they enter the function.
+ Builder builder(*getModule());
+ std::vector<Expression*> fixes;
+ auto num = func->getNumParams();
+ for (Index i = 0; i < num; i++) {
+ if (func->getLocalType(i) == Type::f32) {
+ fixes.push_back(builder.makeLocalSet(
+ i,
+ builder.makeCall(
+ "deNan32", {builder.makeLocalGet(i, Type::f32)}, Type::f32)));
+ } else if (func->getLocalType(i) == Type::f64) {
+ fixes.push_back(builder.makeLocalSet(
+ i,
+ builder.makeCall(
+ "deNan64", {builder.makeLocalGet(i, Type::f64)}, Type::f64)));
+ }
+ }
+ if (!fixes.empty()) {
+ fixes.push_back(func->body);
+ func->body = builder.makeBlock(fixes);
+ // Merge blocks so we don't add an unnecessary one.
+ PassRunner runner(getModule(), getPassOptions());
+ runner.setIsNested(true);
+ runner.add("merge-blocks");
+ runner.run();
+ }
+ }
+
void visitModule(Module* module) {
// Add helper functions.
Builder builder(*module);