summaryrefslogtreecommitdiff
path: root/src/ir/bits.h
blob: e0bca8d879baf8d09bcef854595434b9669a264f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
/*
 * Copyright 2017 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef wasm_ir_bits_h
#define wasm_ir_bits_h

#include "ir/literal-utils.h"
#include "support/bits.h"
#include "wasm-builder.h"
#include <ir/load-utils.h>

namespace wasm {

namespace Bits {

// Get a mask that keeps only the low `bits` bits.
//
// A count of 32 or more yields a mask with all bits set, and a count of zero
// or less yields an empty mask (0).
inline int32_t lowBitMask(int32_t bits) {
  if (bits <= 0) {
    // Shifting a 32-bit value by 32 or more is undefined behavior, so the
    // empty mask has to be handled explicitly rather than computed below.
    return 0;
  }
  uint32_t ret = -1;
  if (bits >= 32) {
    return ret;
  }
  // Here 1 <= bits <= 31, so the shift amount is in the range [1, 31].
  return ret >> (32 - bits);
}

// Tests whether |mask| is a "low bits" mask: a contiguous run of 1s starting
// at bit 0, with only 0s above it. If so, the length of that run is returned;
// otherwise (including for a mask of zero) the result is 0.
inline uint32_t getMaskedBits(uint32_t mask) {
  // Zero has no set bits at all, so it is not considered a mask.
  if (mask == 0) {
    return 0;
  }
  // All ones: every bit is masked. This is handled up front since mask + 1
  // would wrap around to zero.
  if (mask == uint32_t(-1)) {
    return 32;
  }
  // Adding one to a low-bits mask carries through every set bit, leaving a
  // single higher bit, so the two values share no bits:
  //   00011111 & 00100000 => 0
  // Any bit pattern with a hole in it keeps at least one bit in common.
  const bool isLowBitsMask = (mask & (mask + 1)) == 0;
  return isLowBitsMask ? 32 - CountLeadingZeroes(mask) : 0;
}

// Computes how many bits a shift by |amount| really shifts for an operand of
// the given |type|. Wasm only looks at the low 5 bits of the amount for 32-bit
// shifts and the low 6 bits for 64-bit shifts. Any type other than i32 or i64
// is unexpected here.
inline Index getEffectiveShifts(Index amount, Type type) {
  const bool is32 = type == Type::i32;
  if (is32 || type == Type::i64) {
    const Index relevantBits = is32 ? 31 : 63;
    return amount & relevantBits;
  }
  WASM_UNREACHABLE("unexpected type");
}

// Overload that reads the shift amount out of a constant expression. |expr|
// is cast to a Const, and its own type (i32 or i64) selects how many bits of
// the amount are relevant.
inline Index getEffectiveShifts(Expression* expr) {
  auto* constant = expr->cast<Const>();
  const auto& literal = constant->value;
  if (constant->type == Type::i32) {
    return getEffectiveShifts(literal.geti32(), Type::i32);
  }
  if (constant->type == Type::i64) {
    return getEffectiveShifts(literal.geti64(), Type::i64);
  }
  WASM_UNREACHABLE("unexpected type");
}

// Builds an expression that sign-extends the low |bytes| bytes of |value| to
// the full width of its type (i32 or i64). The extension is emitted as a
// shift left followed by an arithmetic shift right by the same amount. When
// |bytes| already equals the full width of the type, |value| is returned
// unchanged.
inline Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) {
  if (value->type == Type::i32) {
    if (bytes != 1 && bytes != 2) {
      // Only the full width remains as a valid option; nothing to do for it.
      assert(bytes == 4);
      return value;
    }
    // Number of bits above the narrow value: 24 for one byte, 16 for two.
    const int32_t shiftAmount = 32 - 8 * bytes;
    Builder builder(wasm);
    auto* shiftedUp = builder.makeBinary(
      ShlInt32, value, LiteralUtils::makeFromInt32(shiftAmount, Type::i32, wasm));
    return builder.makeBinary(
      ShrSInt32,
      shiftedUp,
      LiteralUtils::makeFromInt32(shiftAmount, Type::i32, wasm));
  }

  assert(value->type == Type::i64);
  if (bytes != 1 && bytes != 2 && bytes != 4) {
    // Only the full width remains as a valid option; nothing to do for it.
    assert(bytes == 8);
    return value;
  }
  // Number of bits above the narrow value: 56, 48 or 32.
  const int32_t shiftAmount = 64 - 8 * bytes;
  Builder builder(wasm);
  auto* shiftedUp = builder.makeBinary(
    ShlInt64, value, LiteralUtils::makeFromInt32(shiftAmount, Type::i64, wasm));
  return builder.makeBinary(
    ShrSInt64,
    shiftedUp,
    LiteralUtils::makeFromInt32(shiftAmount, Type::i64, wasm));
}

// getMaxBits() helper that has pessimistic results for the bits used in locals:
// it knows nothing about what was written to a local, so it reports the full
// width of the local's integer type.
struct DummyLocalInfoProvider {
  // Returns 32 for an i32 local.get and 64 for an i64 local.get. Any other
  // type has no integer bit size and is not expected here.
  Index getMaxBitsForLocal(LocalGet* get) {
    if (get->type == Type::i32) {
      return 32;
    }
    if (get->type == Type::i64) {
      return 64;
    }
    WASM_UNREACHABLE("type has no integer bit size");
  }
};

// Returns the maximum number of bits used in an integer expression, i.e. an
// upper bound on the position of the highest bit that may be set.
//
// The result is not extremely precise (it doesn't look into add operands,
// etc.); when nothing better is known, the full width of the expression's type
// is returned (32 for i32, 64 for i64).
//
// LocalInfoProvider is an optional class that can provide answers about
// local.get through a getMaxBitsForLocal(LocalGet*) method. When no provider
// is passed (localInfoProvider is null), a local.get is treated
// pessimistically and the full width of its type is returned.
template<typename LocalInfoProvider = DummyLocalInfoProvider>
Index getMaxBits(Expression* curr,
                 LocalInfoProvider* localInfoProvider = nullptr) {
  if (auto* const_ = curr->dynCast<Const>()) {
    // For a constant the answer is exact: everything below the leading zeros.
    switch (curr->type.getSingle()) {
      case Type::i32:
        return 32 - const_->value.countLeadingZeroes().geti32();
      case Type::i64:
        return 64 - const_->value.countLeadingZeroes().geti64();
      default:
        WASM_UNREACHABLE("invalid type");
    }
  } else if (auto* binary = curr->dynCast<Binary>()) {
    switch (binary->op) {
      // 32-bit
      case AddInt32:
      case SubInt32:
      case MulInt32:
      case DivSInt32:
      case DivUInt32:
      case RemSInt32:
      case RemUInt32:
      case RotLInt32:
      case RotRInt32:
        // No attempt is made to analyze these; assume all bits may be used.
        return 32;
      case AndInt32:
        // A bit can only be set in the result if it is set in both operands.
        return std::min(getMaxBits(binary->left, localInfoProvider),
                        getMaxBits(binary->right, localInfoProvider));
      case OrInt32:
      case XorInt32:
        // A bit can only be set in the result if it is set in some operand.
        return std::max(getMaxBits(binary->left, localInfoProvider),
                        getMaxBits(binary->right, localInfoProvider));
      case ShlInt32: {
        if (auto* shifts = binary->right->dynCast<Const>()) {
          // Shifting left moves the highest bit up, but never past bit 32.
          return std::min(Index(32),
                          getMaxBits(binary->left, localInfoProvider) +
                            Bits::getEffectiveShifts(shifts));
        }
        return 32;
      }
      case ShrUInt32: {
        if (auto* shift = binary->right->dynCast<Const>()) {
          auto maxBits = getMaxBits(binary->left, localInfoProvider);
          auto shifts =
            std::min(Index(Bits::getEffectiveShifts(shift)),
                     maxBits); // can ignore more shifts than zero us out
          return std::max(Index(0), maxBits - shifts);
        }
        return 32;
      }
      case ShrSInt32: {
        if (auto* shift = binary->right->dynCast<Const>()) {
          auto maxBits = getMaxBits(binary->left, localInfoProvider);
          // If all 32 bits may be used then the sign bit may be set, and a
          // signed shift would then fill the top with ones.
          if (maxBits == 32) {
            return 32;
          }
          auto shifts =
            std::min(Index(Bits::getEffectiveShifts(shift)),
                     maxBits); // can ignore more shifts than zero us out
          return std::max(Index(0), maxBits - shifts);
        }
        return 32;
      }
      // 64-bit TODO
      // comparisons
      case EqInt32:
      case NeInt32:
      case LtSInt32:
      case LtUInt32:
      case LeSInt32:
      case LeUInt32:
      case GtSInt32:
      case GtUInt32:
      case GeSInt32:
      case GeUInt32:
      case EqInt64:
      case NeInt64:
      case LtSInt64:
      case LtUInt64:
      case LeSInt64:
      case LeUInt64:
      case GtSInt64:
      case GtUInt64:
      case GeSInt64:
      case GeUInt64:
      case EqFloat32:
      case NeFloat32:
      case LtFloat32:
      case LeFloat32:
      case GtFloat32:
      case GeFloat32:
      case EqFloat64:
      case NeFloat64:
      case LtFloat64:
      case LeFloat64:
      case GtFloat64:
      case GeFloat64:
        // Comparisons are reported as using a single bit.
        return 1;
      default: {
        // Anything else falls through to the type-based answer at the end.
      }
    }
  } else if (auto* unary = curr->dynCast<Unary>()) {
    switch (unary->op) {
      case ClzInt32:
      case CtzInt32:
      case PopcntInt32:
        // The result is at most 32, which needs 6 bits.
        return 6;
      case ClzInt64:
      case CtzInt64:
      case PopcntInt64:
        // The result is at most 64, which needs 7 bits.
        return 7;
      case EqZInt32:
      case EqZInt64:
        return 1;
      case WrapInt64:
        // Wrapping keeps the low 32 bits, so the operand's bound still holds,
        // capped at 32.
        return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
      default: {
        // Anything else falls through to the type-based answer at the end.
      }
    }
  } else if (auto* set = curr->dynCast<LocalSet>()) {
    // a tee passes through the value
    return getMaxBits(set->value, localInfoProvider);
  } else if (auto* get = curr->dynCast<LocalGet>()) {
    // The provider is optional (the parameter defaults to nullptr), so it may
    // only be consulted when one was actually given. Without one, fall through
    // to the pessimistic type-based answer at the end.
    if (localInfoProvider) {
      return localInfoProvider->getMaxBitsForLocal(get);
    }
  } else if (auto* load = curr->dynCast<Load>()) {
    // if signed, then the sign-extension might fill all the bits
    // if unsigned, then we have a limit
    if (LoadUtils::isSignRelevant(load) && !load->signed_) {
      return 8 * load->bytes;
    }
  }
  // Nothing more precise is known: use the full width of the type.
  switch (curr->type.getSingle()) {
    case Type::i32:
      return 32;
    case Type::i64:
      return 64;
    case Type::unreachable:
      return 64; // not interesting, but don't crash
    default:
      WASM_UNREACHABLE("invalid type");
  }
}

} // namespace Bits

} // namespace wasm

#endif // wasm_ir_bits_h