summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/passes/CMakeLists.txt1
-rw-r--r--src/passes/LLVMNontrappingFPToIntLowering.cpp180
-rw-r--r--src/passes/pass.cpp4
-rw-r--r--src/passes/passes.h1
4 files changed, 186 insertions, 0 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index c6e079079..e46163406 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -76,6 +76,7 @@ set(passes_SOURCES
NameList.cpp
NameTypes.cpp
NoInline.cpp
+ LLVMNontrappingFPToIntLowering.cpp
OnceReduction.cpp
OptimizeAddedConstants.cpp
OptimizeCasts.cpp
diff --git a/src/passes/LLVMNontrappingFPToIntLowering.cpp b/src/passes/LLVMNontrappingFPToIntLowering.cpp
new file mode 100644
index 000000000..d14e58af8
--- /dev/null
+++ b/src/passes/LLVMNontrappingFPToIntLowering.cpp
@@ -0,0 +1,180 @@
+#include "pass.h"
+#include "wasm-builder.h"
+#include "wasm.h"
+#include <limits>
+#include <memory>
+
+// By default LLVM emits nontrapping float-to-int instructions to implement its
+// fptoui/fptosi conversion instructions. This pass replaces these instructions
+// with code sequences which also implement LLVM's fptoui/fptosi, but which are
+// not semantically equivalent in wasm. This is because out-of-range inputs to
+// these instructions produce poison values. So we need only ensure that there
+// is no trap, but need not ensure any particular result. The transformation
+// in this pass is the same as the one used by LLVM to lower fptoui/fptosi
+// to wasm trapping instructions.
+
+// For example, if a conversion is guarded by a range check in the source, LLVM
+// can move the conversion before the check (and instead guard the use of the
+// result, which may be poison). This is valid in LLVM and for the nontrapping
+// wasm fptoint instructions but not for the trapping conversions. The
+// transformation in this pass is valid only if the nontrapping conversions
+// in the wasm were generated from LLVM and implement LLVM's conversion
+// semantics.
+
+namespace wasm {
+struct LLVMNonTrappingFPToIntLoweringImpl
+ : public WalkerPass<PostWalker<LLVMNonTrappingFPToIntLoweringImpl>> {
+ bool isFunctionParallel() override { return true; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<LLVMNonTrappingFPToIntLoweringImpl>();
+ }
+
+ UnaryOp getReplacementOp(UnaryOp op) {
+ switch (op) {
+ case TruncSatSFloat32ToInt32:
+ return TruncSFloat32ToInt32;
+ case TruncSatUFloat32ToInt32:
+ return TruncUFloat32ToInt32;
+ case TruncSatSFloat64ToInt32:
+ return TruncSFloat64ToInt32;
+ case TruncSatUFloat64ToInt32:
+ return TruncUFloat64ToInt32;
+ case TruncSatSFloat32ToInt64:
+ return TruncSFloat32ToInt64;
+ case TruncSatUFloat32ToInt64:
+ return TruncUFloat32ToInt64;
+ case TruncSatSFloat64ToInt64:
+ return TruncSFloat64ToInt64;
+ case TruncSatUFloat64ToInt64:
+ return TruncUFloat64ToInt64;
+ default:
+ WASM_UNREACHABLE("Unexpected opcode");
+ }
+ }
+
+ template<typename From, typename To> void replaceSigned(Unary* curr) {
+ BinaryOp ltOp;
+ UnaryOp absOp;
+ switch (curr->op) {
+ case TruncSatSFloat32ToInt32:
+ case TruncSatSFloat32ToInt64:
+ ltOp = LtFloat32;
+ absOp = AbsFloat32;
+ break;
+ case TruncSatSFloat64ToInt32:
+ case TruncSatSFloat64ToInt64:
+ ltOp = LtFloat64;
+ absOp = AbsFloat64;
+ break;
+ default:
+ WASM_UNREACHABLE("Unexpected opcode");
+ }
+
+ Builder builder(*getModule());
+ Index v = Builder::addVar(getFunction(), curr->value->type);
+ // if fabs(operand) < INT_MAX then use the trapping operation, else return
+ // INT_MIN. The altnernate value is correct for the case where the input is
+ // INT_MIN itself; otherwise it's UB so any value will do.
+ replaceCurrent(builder.makeIf(
+ builder.makeBinary(
+ ltOp,
+ builder.makeUnary(
+ absOp, builder.makeLocalTee(v, curr->value, curr->value->type)),
+ builder.makeConst(static_cast<From>(std::numeric_limits<To>::max()))),
+ builder.makeUnary(getReplacementOp(curr->op),
+ builder.makeLocalGet(v, curr->value->type)),
+ builder.makeConst(std::numeric_limits<To>::min())));
+ }
+
+ template<typename From, typename To> void replaceUnsigned(Unary* curr) {
+ BinaryOp ltOp, geOp;
+
+ switch (curr->op) {
+ case TruncSatUFloat32ToInt32:
+ case TruncSatUFloat32ToInt64:
+ ltOp = LtFloat32;
+ geOp = GeFloat32;
+ break;
+ case TruncSatUFloat64ToInt32:
+ case TruncSatUFloat64ToInt64:
+ ltOp = LtFloat64;
+ geOp = GeFloat64;
+ break;
+ default:
+ WASM_UNREACHABLE("Unexpected opcode");
+ }
+
+ Builder builder(*getModule());
+ Index v = Builder::addVar(getFunction(), curr->value->type);
+ // if op < INT_MAX and op >= 0 then use the trapping operation, else return
+ // 0
+ replaceCurrent(builder.makeIf(
+ builder.makeBinary(
+ AndInt32,
+ builder.makeBinary(
+ ltOp,
+ builder.makeLocalTee(v, curr->value, curr->value->type),
+ builder.makeConst(static_cast<From>(std::numeric_limits<To>::max()))),
+ builder.makeBinary(geOp,
+ builder.makeLocalGet(v, curr->value->type),
+ builder.makeConst(static_cast<From>(0.0)))),
+ builder.makeUnary(getReplacementOp(curr->op),
+ builder.makeLocalGet(v, curr->value->type)),
+ builder.makeConst(static_cast<To>(0))));
+ }
+
+ void visitUnary(Unary* curr) {
+ switch (curr->op) {
+ case TruncSatSFloat32ToInt32:
+ replaceSigned<float, int32_t>(curr);
+ break;
+ case TruncSatSFloat64ToInt32:
+ replaceSigned<double, int32_t>(curr);
+ break;
+ case TruncSatSFloat32ToInt64:
+ replaceSigned<float, int64_t>(curr);
+ break;
+ case TruncSatSFloat64ToInt64:
+ replaceSigned<double, int64_t>(curr);
+ break;
+ case TruncSatUFloat32ToInt32:
+ replaceUnsigned<float, uint32_t>(curr);
+ break;
+ case TruncSatUFloat64ToInt32:
+ replaceUnsigned<double, uint32_t>(curr);
+ break;
+ case TruncSatUFloat32ToInt64:
+ replaceUnsigned<float, uint64_t>(curr);
+ break;
+ case TruncSatUFloat64ToInt64:
+ replaceUnsigned<double, uint64_t>(curr);
+ break;
+ default:
+ break;
+ }
+ }
+
+ void doWalkFunction(Function* func) { Super::doWalkFunction(func); }
+};
+
+struct LLVMNonTrappingFPToIntLowering : public Pass {
+ void run(Module* module) override {
+ if (!module->features.hasTruncSat()) {
+ return;
+ }
+ PassRunner runner(module);
+ // Run the Impl pass as an inner pass in parallel. This pass updates the
+ // module features, so it can't be parallel.
+ runner.add(std::make_unique<LLVMNonTrappingFPToIntLoweringImpl>());
+ runner.setIsNested(true);
+ runner.run();
+ module->features.disable(FeatureSet::TruncSat);
+ }
+};
+
+Pass* createLLVMNonTrappingFPToIntLoweringPass() {
+ return new LLVMNonTrappingFPToIntLowering();
+}
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 5cbf4a31f..4be24cebf 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -334,6 +334,10 @@ void PassRegistry::registerPasses() {
registerPass("no-partial-inline",
"mark functions as no-inline (for partial inlining only)",
createNoPartialInlinePass);
+ registerPass("llvm-nontrapping-fptoint-lowering",
+ "lower nontrapping float-to-int operations to wasm mvp and "
+ "disable the nontrapping fptoint feature",
+ createLLVMNonTrappingFPToIntLoweringPass);
registerPass("once-reduction",
"reduces calls to code that only runs once",
createOnceReductionPass);
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 212a2b0e4..aadd26d41 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -119,6 +119,7 @@ Pass* createOutliningPass();
Pass* createPickLoadSignsPass();
Pass* createModAsyncifyAlwaysOnlyUnwindPass();
Pass* createModAsyncifyNeverUnwindPass();
+Pass* createLLVMNonTrappingFPToIntLoweringPass();
Pass* createPoppifyPass();
Pass* createPostEmscriptenPass();
Pass* createPrecomputePass();