summaryrefslogtreecommitdiff
path: root/src/passes/LLVMNontrappingFPToIntLowering.cpp
blob: d14e58af8063a78a2f1232aa2fe10047140f4858 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#include "pass.h"
#include "wasm-builder.h"
#include "wasm.h"
#include <cmath>
#include <limits>
#include <memory>

// By default LLVM emits nontrapping float-to-int instructions to implement its
// fptoui/fptosi conversion instructions. This pass replaces these instructions
// with code sequences which also implement LLVM's fptoui/fptosi, but which are
// not semantically equivalent in wasm. This is valid because out-of-range
// inputs to LLVM's fptoui/fptosi produce poison values, so we need only ensure
// that there is no trap, but need not ensure any particular result. The
// transformation in this pass is the same as the one used by LLVM to lower
// fptoui/fptosi to wasm trapping instructions.

// A guarded code sequence is needed (rather than simply swapping each
// nontrapping instruction for its trapping counterpart) because LLVM may
// speculate conversions. For example, if a conversion is guarded by a range
// check in the source, LLVM can move the conversion before the check (and
// instead guard the use of the result, which may be poison). This is valid in
// LLVM and for the nontrapping wasm fptoint instructions, but not for the
// trapping conversions. The transformation in this pass is valid only if the
// nontrapping conversions in the wasm were generated from LLVM and implement
// LLVM's conversion semantics.

namespace wasm {
// Function-parallel worker that rewrites every saturating float-to-int
// truncation in a function into a guarded trapping truncation.
struct LLVMNonTrappingFPToIntLoweringImpl
  : public WalkerPass<PostWalker<LLVMNonTrappingFPToIntLoweringImpl>> {
  bool isFunctionParallel() override { return true; }

  std::unique_ptr<Pass> create() override {
    return std::make_unique<LLVMNonTrappingFPToIntLoweringImpl>();
  }

  // Maps a saturating (nontrapping) truncation opcode to the trapping
  // truncation with the same source type, result type and signedness.
  UnaryOp getReplacementOp(UnaryOp op) {
    switch (op) {
      case TruncSatSFloat32ToInt32:
        return TruncSFloat32ToInt32;
      case TruncSatUFloat32ToInt32:
        return TruncUFloat32ToInt32;
      case TruncSatSFloat64ToInt32:
        return TruncSFloat64ToInt32;
      case TruncSatUFloat64ToInt32:
        return TruncUFloat64ToInt32;
      case TruncSatSFloat32ToInt64:
        return TruncSFloat32ToInt64;
      case TruncSatUFloat32ToInt64:
        return TruncUFloat32ToInt64;
      case TruncSatSFloat64ToInt64:
        return TruncSFloat64ToInt64;
      case TruncSatUFloat64ToInt64:
        return TruncUFloat64ToInt64;
      default:
        WASM_UNREACHABLE("Unexpected opcode");
    }
  }

  // Returns 2^N as a From, where N is the number of value bits of To (31 or 63
  // for signed results, 32 or 64 for unsigned results). These powers of two
  // are exactly representable in both float and double.
  //
  // Any input x with |x| < 2^N (signed), or 0 <= x < 2^N (unsigned), truncates
  // to a representable value, so the trapping conversion cannot trap on it.
  //
  // This deliberately does not use static_cast<From>(numeric_limits<To>::max()).
  // For double -> i32 that cast yields 2147483647.0 rather than 2^31, and for
  // double -> u32 it yields 4294967295.0 rather than 2^32. Either bound would
  // wrongly reject valid inputs such as 2147483647.0 itself.
  template<typename From, typename To> static From getTruncLimit() {
    static_assert(std::numeric_limits<To>::is_integer,
                  "truncation target must be an integer type");
    return std::ldexp(From(1), std::numeric_limits<To>::digits);
  }

  // Replaces a signed saturating truncation of x with
  //   if (fabs(x) < 2^N) trunc_s(x) else INT_MIN
  // where N is 31 or 63. Inputs below the limit cannot trap.
  //
  // Every other input falls into one of three cases:
  //  - it is NaN or out of range, which is poison in LLVM, so any value will
  //    do;
  //  - it truncates exactly to INT_MIN (x in (-2^N - 1, -2^N]), for which
  //    INT_MIN is the correct result.
  template<typename From, typename To> void replaceSigned(Unary* curr) {
    BinaryOp ltOp;
    UnaryOp absOp;
    switch (curr->op) {
      case TruncSatSFloat32ToInt32:
      case TruncSatSFloat32ToInt64:
        ltOp = LtFloat32;
        absOp = AbsFloat32;
        break;
      case TruncSatSFloat64ToInt32:
      case TruncSatSFloat64ToInt64:
        ltOp = LtFloat64;
        absOp = AbsFloat64;
        break;
      default:
        WASM_UNREACHABLE("Unexpected opcode");
    }

    Builder builder(*getModule());
    const auto type = curr->value->type;
    // The operand is read twice (once in the check, once in the conversion),
    // so stash it in a fresh local.
    Index v = Builder::addVar(getFunction(), type);
    replaceCurrent(builder.makeIf(
      builder.makeBinary(
        ltOp,
        builder.makeUnary(absOp, builder.makeLocalTee(v, curr->value, type)),
        builder.makeConst(getTruncLimit<From, To>())),
      builder.makeUnary(getReplacementOp(curr->op),
                        builder.makeLocalGet(v, type)),
      builder.makeConst(std::numeric_limits<To>::min())));
  }

  // Replaces an unsigned saturating truncation of x with
  //   if ((x < 2^N) & (x >= 0)) trunc_u(x) else 0
  // where N is 32 or 64. Inputs in that range cannot trap.
  //
  // The remaining inputs fall into one of three cases:
  //  - NaN or out of range, which is poison in LLVM;
  //  - in (-1, 0), which truncate to 0 anyway.
  // Returning 0 is therefore always acceptable.
  template<typename From, typename To> void replaceUnsigned(Unary* curr) {
    BinaryOp ltOp, geOp;

    switch (curr->op) {
      case TruncSatUFloat32ToInt32:
      case TruncSatUFloat32ToInt64:
        ltOp = LtFloat32;
        geOp = GeFloat32;
        break;
      case TruncSatUFloat64ToInt32:
      case TruncSatUFloat64ToInt64:
        ltOp = LtFloat64;
        geOp = GeFloat64;
        break;
      default:
        WASM_UNREACHABLE("Unexpected opcode");
    }

    Builder builder(*getModule());
    const auto type = curr->value->type;
    // The operand is read three times, so stash it in a fresh local.
    Index v = Builder::addVar(getFunction(), type);
    replaceCurrent(builder.makeIf(
      builder.makeBinary(
        AndInt32,
        builder.makeBinary(ltOp,
                           builder.makeLocalTee(v, curr->value, type),
                           builder.makeConst(getTruncLimit<From, To>())),
        builder.makeBinary(geOp,
                           builder.makeLocalGet(v, type),
                           builder.makeConst(static_cast<From>(0.0)))),
      builder.makeUnary(getReplacementOp(curr->op),
                        builder.makeLocalGet(v, type)),
      builder.makeConst(static_cast<To>(0))));
  }

  // Dispatches each saturating truncation to the matching rewrite, with the
  // C++ source/result types used to compute the range limit and the fallback.
  // All other unary ops are left untouched.
  void visitUnary(Unary* curr) {
    switch (curr->op) {
      case TruncSatSFloat32ToInt32:
        replaceSigned<float, int32_t>(curr);
        break;
      case TruncSatSFloat64ToInt32:
        replaceSigned<double, int32_t>(curr);
        break;
      case TruncSatSFloat32ToInt64:
        replaceSigned<float, int64_t>(curr);
        break;
      case TruncSatSFloat64ToInt64:
        replaceSigned<double, int64_t>(curr);
        break;
      case TruncSatUFloat32ToInt32:
        replaceUnsigned<float, uint32_t>(curr);
        break;
      case TruncSatUFloat64ToInt32:
        replaceUnsigned<double, uint32_t>(curr);
        break;
      case TruncSatUFloat32ToInt64:
        replaceUnsigned<float, uint64_t>(curr);
        break;
      case TruncSatUFloat64ToInt64:
        replaceUnsigned<double, uint64_t>(curr);
        break;
      default:
        break;
    }
  }
};

struct LLVMNonTrappingFPToIntLowering : public Pass {
  void run(Module* module) override {
    if (!module->features.hasTruncSat()) {
      return;
    }
    PassRunner runner(module);
    // Run the Impl pass as an inner pass in parallel. This pass updates the
    // module features, so it can't be parallel.
    runner.add(std::make_unique<LLVMNonTrappingFPToIntLoweringImpl>());
    runner.setIsNested(true);
    runner.run();
    module->features.disable(FeatureSet::TruncSat);
  }
};

// Factory used by the pass registry. It returns a newly allocated,
// value-initialized instance of the driver pass, presumably owned by the
// caller afterwards (TODO confirm against the registry).
Pass* createLLVMNonTrappingFPToIntLoweringPass() {
  return new LLVMNonTrappingFPToIntLowering{};
}

} // namespace wasm