diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/LLVMNontrappingFPToIntLowering.cpp | 180 | ||||
-rw-r--r-- | src/passes/pass.cpp | 4 | ||||
-rw-r--r-- | src/passes/passes.h | 1 |
4 files changed, 186 insertions, 0 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index c6e079079..e46163406 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -76,6 +76,7 @@ set(passes_SOURCES NameList.cpp NameTypes.cpp NoInline.cpp + LLVMNontrappingFPToIntLowering.cpp OnceReduction.cpp OptimizeAddedConstants.cpp OptimizeCasts.cpp diff --git a/src/passes/LLVMNontrappingFPToIntLowering.cpp b/src/passes/LLVMNontrappingFPToIntLowering.cpp new file mode 100644 index 000000000..d14e58af8 --- /dev/null +++ b/src/passes/LLVMNontrappingFPToIntLowering.cpp @@ -0,0 +1,180 @@ +#include "pass.h" +#include "wasm-builder.h" +#include "wasm.h" +#include <limits> +#include <memory> + +// By default LLVM emits nontrapping float-to-int instructions to implement its +// fptoui/fptosi conversion instructions. This pass replaces these instructions +// with code sequences which also implement LLVM's fptoui/fptosi, but which are +// not semantically equivalent in wasm. This is because out-of-range inputs to +// these instructions produce poison values. So we need only ensure that there +// is no trap, but need not ensure any particular result. The transformation +// in this pass is the same as the one used by LLVM to lower fptoui/fptosi +// to wasm trapping instructions. + +// For example, if a conversion is guarded by a range check in the source, LLVM +// can move the conversion before the check (and instead guard the use of the +// result, which may be poison). This is valid in LLVM and for the nontrapping +// wasm fptoint instructions but not for the trapping conversions. The +// transformation in this pass is valid only if the nontrapping conversions +// in the wasm were generated from LLVM and implement LLVM's conversion +// semantics. + +namespace wasm { +struct LLVMNonTrappingFPToIntLoweringImpl + : public WalkerPass<PostWalker<LLVMNonTrappingFPToIntLoweringImpl>> { + bool isFunctionParallel() override { return true; } + + std::unique_ptr<Pass> create() override { + return std::make_unique<LLVMNonTrappingFPToIntLoweringImpl>(); + } + + UnaryOp getReplacementOp(UnaryOp op) { + switch (op) { + case TruncSatSFloat32ToInt32: + return TruncSFloat32ToInt32; + case TruncSatUFloat32ToInt32: + return TruncUFloat32ToInt32; + case TruncSatSFloat64ToInt32: + return TruncSFloat64ToInt32; + case TruncSatUFloat64ToInt32: + return TruncUFloat64ToInt32; + case TruncSatSFloat32ToInt64: + return TruncSFloat32ToInt64; + case TruncSatUFloat32ToInt64: + return TruncUFloat32ToInt64; + case TruncSatSFloat64ToInt64: + return TruncSFloat64ToInt64; + case TruncSatUFloat64ToInt64: + return TruncUFloat64ToInt64; + default: + WASM_UNREACHABLE("Unexpected opcode"); + } + } + + template<typename From, typename To> void replaceSigned(Unary* curr) { + BinaryOp ltOp; + UnaryOp absOp; + switch (curr->op) { + case TruncSatSFloat32ToInt32: + case TruncSatSFloat32ToInt64: + ltOp = LtFloat32; + absOp = AbsFloat32; + break; + case TruncSatSFloat64ToInt32: + case TruncSatSFloat64ToInt64: + ltOp = LtFloat64; + absOp = AbsFloat64; + break; + default: + WASM_UNREACHABLE("Unexpected opcode"); + } + + Builder builder(*getModule()); + Index v = Builder::addVar(getFunction(), curr->value->type); + // if fabs(operand) < INT_MAX then use the trapping operation, else return + // INT_MIN. The altnernate value is correct for the case where the input is + // INT_MIN itself; otherwise it's UB so any value will do. + replaceCurrent(builder.makeIf( + builder.makeBinary( + ltOp, + builder.makeUnary( + absOp, builder.makeLocalTee(v, curr->value, curr->value->type)), + builder.makeConst(static_cast<From>(std::numeric_limits<To>::max()))), + builder.makeUnary(getReplacementOp(curr->op), + builder.makeLocalGet(v, curr->value->type)), + builder.makeConst(std::numeric_limits<To>::min()))); + } + + template<typename From, typename To> void replaceUnsigned(Unary* curr) { + BinaryOp ltOp, geOp; + + switch (curr->op) { + case TruncSatUFloat32ToInt32: + case TruncSatUFloat32ToInt64: + ltOp = LtFloat32; + geOp = GeFloat32; + break; + case TruncSatUFloat64ToInt32: + case TruncSatUFloat64ToInt64: + ltOp = LtFloat64; + geOp = GeFloat64; + break; + default: + WASM_UNREACHABLE("Unexpected opcode"); + } + + Builder builder(*getModule()); + Index v = Builder::addVar(getFunction(), curr->value->type); + // if op < INT_MAX and op >= 0 then use the trapping operation, else return + // 0 + replaceCurrent(builder.makeIf( + builder.makeBinary( + AndInt32, + builder.makeBinary( + ltOp, + builder.makeLocalTee(v, curr->value, curr->value->type), + builder.makeConst(static_cast<From>(std::numeric_limits<To>::max()))), + builder.makeBinary(geOp, + builder.makeLocalGet(v, curr->value->type), + builder.makeConst(static_cast<From>(0.0)))), + builder.makeUnary(getReplacementOp(curr->op), + builder.makeLocalGet(v, curr->value->type)), + builder.makeConst(static_cast<To>(0)))); + } + + void visitUnary(Unary* curr) { + switch (curr->op) { + case TruncSatSFloat32ToInt32: + replaceSigned<float, int32_t>(curr); + break; + case TruncSatSFloat64ToInt32: + replaceSigned<double, int32_t>(curr); + break; + case TruncSatSFloat32ToInt64: + replaceSigned<float, int64_t>(curr); + break; + case TruncSatSFloat64ToInt64: + replaceSigned<double, int64_t>(curr); + break; + case TruncSatUFloat32ToInt32: + replaceUnsigned<float, uint32_t>(curr); + break; + case TruncSatUFloat64ToInt32: + replaceUnsigned<double, uint32_t>(curr); + break; + case TruncSatUFloat32ToInt64: + replaceUnsigned<float, uint64_t>(curr); + break; + case TruncSatUFloat64ToInt64: + replaceUnsigned<double, uint64_t>(curr); + break; + default: + break; + } + } + + void doWalkFunction(Function* func) { Super::doWalkFunction(func); } +}; + +struct LLVMNonTrappingFPToIntLowering : public Pass { + void run(Module* module) override { + if (!module->features.hasTruncSat()) { + return; + } + PassRunner runner(module); + // Run the Impl pass as an inner pass in parallel. This pass updates the + // module features, so it can't be parallel. + runner.add(std::make_unique<LLVMNonTrappingFPToIntLoweringImpl>()); + runner.setIsNested(true); + runner.run(); + module->features.disable(FeatureSet::TruncSat); + } +}; + +Pass* createLLVMNonTrappingFPToIntLoweringPass() { + return new LLVMNonTrappingFPToIntLowering(); +} + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 5cbf4a31f..4be24cebf 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -334,6 +334,10 @@ void PassRegistry::registerPasses() { registerPass("no-partial-inline", "mark functions as no-inline (for partial inlining only)", createNoPartialInlinePass); + registerPass("llvm-nontrapping-fptoint-lowering", + "lower nontrapping float-to-int operations to wasm mvp and " + "disable the nontrapping fptoint feature", + createLLVMNonTrappingFPToIntLoweringPass); registerPass("once-reduction", "reduces calls to code that only runs once", createOnceReductionPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index 212a2b0e4..aadd26d41 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -119,6 +119,7 @@ Pass* createOutliningPass(); Pass* createPickLoadSignsPass(); Pass* createModAsyncifyAlwaysOnlyUnwindPass(); Pass* createModAsyncifyNeverUnwindPass(); +Pass* createLLVMNonTrappingFPToIntLoweringPass(); Pass* createPoppifyPass(); Pass* createPostEmscriptenPass(); Pass* createPrecomputePass(); |