summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ir/properties.h23
-rw-r--r--src/passes/DeNaN.cpp46
2 files changed, 67 insertions, 2 deletions
diff --git a/src/ir/properties.h b/src/ir/properties.h
index 094d90bd6..0c6824e4a 100644
--- a/src/ir/properties.h
+++ b/src/ir/properties.h
@@ -200,7 +200,9 @@ inline Index getZeroExtBits(Expression* curr) {
}
// Returns a falling-through value, that is, it looks through a local.tee
-// and other operations that receive a value and let it flow through them.
+// and other operations that receive a value and let it flow through them. If
+// there is no value falling through, returns the node itself (as that is the
+// value that trivially falls through, with 0 steps in the middle).
inline Expression* getFallthrough(Expression* curr,
const PassOptions& passOptions,
FeatureSet features) {
@@ -241,6 +243,25 @@ inline Expression* getFallthrough(Expression* curr,
return curr;
}
+// Returns whether the resulting value here must fall through without being
+// modified. For example, a tee always does so. That is, this returns false if
+// and only if the return value may have some computation performed on it to
+// change it from the inputs the instruction receives.
+// This differs from getFallthrough() which returns a single value that falls
+// through - here if more than one value can fall through, like in if-else,
+// we can return true. That is, there we care about a value falling through and
+// for us to get that actual value to look at; here we just care whether the
+// value falls through without being changed, even if it might be one of
+// several options.
+inline bool isResultFallthrough(Expression* curr) {
+ // Note that we don't check if there is a return value here; the node may be
+ // unreachable, for example, but then there is no meaningful answer to give
+ // anyhow.
+ return curr->is<LocalSet>() || curr->is<Block>() || curr->is<If>() ||
+ curr->is<Loop>() || curr->is<Try>() || curr->is<Select>() ||
+ curr->is<Break>();
+}
+
} // namespace Properties
} // namespace wasm
diff --git a/src/passes/DeNaN.cpp b/src/passes/DeNaN.cpp
index 044c13c86..c894cb3f1 100644
--- a/src/passes/DeNaN.cpp
+++ b/src/passes/DeNaN.cpp
@@ -22,6 +22,7 @@
// differ on wasm's nondeterminism around NaNs.
//
+#include "ir/properties.h"
#include "pass.h"
#include "wasm-builder.h"
#include "wasm.h"
@@ -33,7 +34,18 @@ struct DeNaN : public WalkerPass<
void visitExpression(Expression* expr) {
// If the expression returns a floating-point value, ensure it is not a
// NaN. If we can do this at compile time, do it now, which is useful for
- // initializations of global (which we can't do a function call in).
+ // initializations of global (which we can't do a function call in). Note
+ // that we don't instrument local.gets, which would cause problems if we
+ // ran this pass more than once (the added functions use gets, and we don't
+ // want to instrument them).
+ if (expr->is<LocalGet>()) {
+ return;
+ }
+ // If the result just falls through without being modified, then we've
+ // already fixed it up earlier.
+ if (Properties::isResultFallthrough(expr)) {
+ return;
+ }
Builder builder(*getModule());
Expression* replacement = nullptr;
auto* c = expr->dynCast<Const>();
@@ -61,6 +73,38 @@ struct DeNaN : public WalkerPass<
}
}
+ void visitFunction(Function* func) {
+ if (func->imported()) {
+ return;
+ }
+ // Instrument all locals as they enter the function.
+ Builder builder(*getModule());
+ std::vector<Expression*> fixes;
+ auto num = func->getNumParams();
+ for (Index i = 0; i < num; i++) {
+ if (func->getLocalType(i) == Type::f32) {
+ fixes.push_back(builder.makeLocalSet(
+ i,
+ builder.makeCall(
+ "deNan32", {builder.makeLocalGet(i, Type::f32)}, Type::f32)));
+ } else if (func->getLocalType(i) == Type::f64) {
+ fixes.push_back(builder.makeLocalSet(
+ i,
+ builder.makeCall(
+ "deNan64", {builder.makeLocalGet(i, Type::f64)}, Type::f64)));
+ }
+ }
+ if (!fixes.empty()) {
+ fixes.push_back(func->body);
+ func->body = builder.makeBlock(fixes);
+ // Merge blocks so we don't add an unnecessary one.
+ PassRunner runner(getModule(), getPassOptions());
+ runner.setIsNested(true);
+ runner.add("merge-blocks");
+ runner.run();
+ }
+ }
+
void visitModule(Module* module) {
// Add helper functions.
Builder builder(*module);