summaryrefslogtreecommitdiff
path: root/src/passes/I64ToI32Lowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/I64ToI32Lowering.cpp')
-rw-r--r--src/passes/I64ToI32Lowering.cpp252
1 files changed, 215 insertions, 37 deletions
diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp
index e6da9a1b8..2f6fdf122 100644
--- a/src/passes/I64ToI32Lowering.cpp
+++ b/src/passes/I64ToI32Lowering.cpp
@@ -40,10 +40,10 @@ static Name makeHighName(Name n) {
struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
struct TempVar {
- TempVar(Index idx, I64ToI32Lowering& pass) :
- idx(idx), pass(pass), moved(false) {}
+ TempVar(Index idx, Type ty, I64ToI32Lowering& pass) :
+ idx(idx), pass(pass), moved(false), ty(ty) {}
- TempVar(TempVar&& other) : idx(other), pass(other.pass), moved(false) {
+ TempVar(TempVar&& other) : idx(other), pass(other.pass), moved(false), ty(other.ty) {
assert(!other.moved);
other.moved = true;
}
@@ -78,18 +78,17 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
private:
void freeIdx() {
- assert(std::find(pass.freeTemps.begin(), pass.freeTemps.end(), idx) ==
- pass.freeTemps.end());
- pass.freeTemps.push_back(idx);
+ auto &freeList = pass.freeTemps[(int) ty];
+ assert(std::find(freeList.begin(), freeList.end(), idx) == freeList.end());
+ freeList.push_back(idx);
}
Index idx;
I64ToI32Lowering& pass;
bool moved; // since C++ will still destruct moved-from values
+ Type ty;
};
- static Name highBitsGlobal;
-
// false since function types need to be lowered
// TODO: allow module-level transformations in parallel passes
bool isFunctionParallel() override { return false; }
@@ -114,7 +113,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
// to return the high 32 bits.
auto* highBits = new Global();
highBits->type = i32;
- highBits->name = highBitsGlobal;
+ highBits->name = INT64_TO_32_HIGH_BITS;
highBits->init = builder->makeConst(Literal(int32_t(0)));
highBits->mutable_ = true;
module->addGlobal(highBits);
@@ -159,7 +158,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
BinaryOp secondOp = leftShift ? ShrUInt32 : ShlInt32;
Block* equalRotateBlock = builder->blockify(
builder->makeSetGlobal(
- highBitsGlobal,
+ INT64_TO_32_HIGH_BITS,
builder->makeGetLocal(lowBits, i32)
),
builder->makeGetLocal(highBits, i32)
@@ -182,7 +181,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
)
),
builder->makeSetGlobal(
- highBitsGlobal,
+ INT64_TO_32_HIGH_BITS,
builder->makeBinary(
OrInt32,
builder->makeBinary(
@@ -221,7 +220,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
)
),
builder->makeSetGlobal(
- highBitsGlobal,
+ INT64_TO_32_HIGH_BITS,
builder->makeBinary(
OrInt32,
builder->makeBinary(
@@ -328,18 +327,17 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
func->body
);
SetGlobal* setHigh = builder->makeSetGlobal(
- highBitsGlobal,
+ INT64_TO_32_HIGH_BITS,
builder->makeGetLocal(highBits, i32)
);
GetLocal* getLow = builder->makeGetLocal(lowBits, i32);
func->body = builder->blockify(setLow, setHigh, getLow);
}
}
- assert(freeTemps.size() == nextTemp - func->getNumLocals());
int idx = 0;
for (size_t i = func->getNumLocals(); i < nextTemp; i++) {
Name tmpName("i64toi32_i32$" + std::to_string(idx++));
- builder->addVar(func, tmpName, i32);
+ builder->addVar(func, tmpName, tempTypes[i]);
}
}
@@ -488,7 +486,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
);
SetLocal* setHigh = builder->makeSetLocal(
highBits,
- builder->makeGetGlobal(highBitsGlobal, i32)
+ builder->makeGetGlobal(INT64_TO_32_HIGH_BITS, i32)
);
GetLocal* getLow = builder->makeGetLocal(lowBits, i32);
Block* result = builder->blockify(doCall, setHigh, getLow);
@@ -733,15 +731,14 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
void lowerReinterpretFloat64(Unary* curr) {
// Assume that the wasm file assumes the address 0 is invalid and roundtrip
// our f64 through memory at address 0
- Expression* zero = builder->makeConst(Literal(int32_t(0)));
TempVar highBits = getTemp();
Block *result = builder->blockify(
- builder->makeStore(8, 0, 8, zero, curr->value, f64),
+ builder->makeStore(8, 0, 8, builder->makeConst(Literal(int32_t(0))), curr->value, f64),
builder->makeSetLocal(
highBits,
- builder->makeLoad(4, true, 4, 4, zero, i32)
+ builder->makeLoad(4, true, 4, 4, builder->makeConst(Literal(int32_t(0))), i32)
),
- builder->makeLoad(4, true, 0, 4, zero, i32)
+ builder->makeLoad(4, true, 0, 4, builder->makeConst(Literal(int32_t(0))), i32)
);
setOutParam(result, std::move(highBits));
replaceCurrent(result);
@@ -751,13 +748,192 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
// Assume that the wasm file assumes the address 0 is invalid and roundtrip
// our i64 through memory at address 0
TempVar highBits = fetchOutParam(curr->value);
- TempVar lowBits = getTemp();
- Expression* zero = builder->makeConst(Literal(int32_t(0)));
Block *result = builder->blockify(
- builder->makeStore(4, 0, 4, zero, curr->value, i32),
- builder->makeStore(4, 4, 4, zero, builder->makeGetLocal(highBits, i32), i32),
- builder->makeLoad(8, true, 0, 8, zero, f64)
+ builder->makeStore(4, 0, 4, builder->makeConst(Literal(int32_t(0))), curr->value, i32),
+ builder->makeStore(4, 4, 4, builder->makeConst(Literal(int32_t(0))), builder->makeGetLocal(highBits, i32), i32),
+ builder->makeLoad(8, true, 0, 8, builder->makeConst(Literal(int32_t(0))), f64)
+ );
+ replaceCurrent(result);
+ }
+
+ void lowerTruncFloatToInt(Unary *curr) {
+ // hiBits = if abs(f) >= 1.0 {
+ // if f > 0.0 {
+ // (unsigned) min(
+ // floor(f / (float) U32_MAX),
+ // (float) U32_MAX - 1,
+ // )
+ // } else {
+ // (unsigned) ceil((f - (float) (unsigned) f) / ((float) U32_MAX))
+ // }
+ // } else {
+ // 0
+ // }
+ //
+ // loBits = (unsigned) f;
+
+ Literal litZero, litOne, u32Max;
+ UnaryOp trunc, convert, abs, floor, ceil;
+ Type localType;
+ BinaryOp ge, gt, min, div, sub;
+ switch (curr->op) {
+ case TruncSFloat32ToInt64:
+ case TruncUFloat32ToInt64: {
+ litZero = Literal((float) 0);
+ litOne = Literal((float) 1);
+ u32Max = Literal(((float) UINT_MAX) + 1);
+ trunc = TruncUFloat32ToInt32;
+ convert = ConvertUInt32ToFloat32;
+ localType = f32;
+ abs = AbsFloat32;
+ ge = GeFloat32;
+ gt = GtFloat32;
+ min = MinFloat32;
+ floor = FloorFloat32;
+ ceil = CeilFloat32;
+ div = DivFloat32;
+ sub = SubFloat32;
+ break;
+ }
+ case TruncSFloat64ToInt64:
+ case TruncUFloat64ToInt64: {
+ litZero = Literal((double) 0);
+ litOne = Literal((double) 1);
+ u32Max = Literal(((double) UINT_MAX) + 1);
+ trunc = TruncUFloat64ToInt32;
+ convert = ConvertUInt32ToFloat64;
+ localType = f64;
+ abs = AbsFloat64;
+ ge = GeFloat64;
+ gt = GtFloat64;
+ min = MinFloat64;
+ floor = FloorFloat64;
+ ceil = CeilFloat64;
+ div = DivFloat64;
+ sub = SubFloat64;
+ break;
+ }
+ default: abort();
+ }
+
+ TempVar f = getTemp(localType);
+ TempVar highBits = getTemp();
+
+ Expression *gtZeroBranch = builder->makeBinary(
+ min,
+ builder->makeUnary(
+ floor,
+ builder->makeBinary(
+ div,
+ builder->makeGetLocal(f, localType),
+ builder->makeConst(u32Max)
+ )
+ ),
+ builder->makeBinary(sub, builder->makeConst(u32Max), builder->makeConst(litOne))
+ );
+ Expression *ltZeroBranch = builder->makeUnary(
+ ceil,
+ builder->makeBinary(
+ div,
+ builder->makeBinary(
+ sub,
+ builder->makeGetLocal(f, localType),
+ builder->makeUnary(convert,
+ builder->makeUnary(trunc, builder->makeGetLocal(f, localType))
+ )
+ ),
+ builder->makeConst(u32Max)
+ )
+ );
+
+ If *highBitsCalc = builder->makeIf(
+ builder->makeBinary(
+ gt,
+ builder-> makeGetLocal(f, localType),
+ builder->makeConst(litZero)
+ ),
+ builder->makeUnary(trunc, gtZeroBranch),
+ builder->makeUnary(trunc, ltZeroBranch)
+ );
+ If *highBitsVal = builder->makeIf(
+ builder->makeBinary(
+ ge,
+ builder->makeUnary(abs, builder->makeGetLocal(f, localType)),
+ builder->makeConst(litOne)
+ ),
+ highBitsCalc,
+ builder->makeConst(Literal(int32_t(0)))
+ );
+ Block *result = builder->blockify(
+ builder->makeSetLocal(f, curr->value),
+ builder->makeSetLocal(highBits, highBitsVal),
+ builder->makeUnary(trunc, builder->makeGetLocal(f, localType))
+ );
+ setOutParam(result, std::move(highBits));
+ replaceCurrent(result);
+ }
+
+ void lowerConvertIntToFloat(Unary *curr) {
+ // Here the same strategy as `emcc` is taken which takes the two halves of
+ // the 64-bit integer and creates a mathematical expression using float
+ // arithmetic to reassemble the final floating point value.
+ //
+ // For example for i64 -> f32 we generate:
+ //
+ // ((double) (unsigned) lowBits) + ((double) U32_MAX) * ((double) (int) highBits)
+ //
+ // Mostly just shuffling things around here with coercions and whatnot!
+ // Note though that all arithmetic is done with f64 to have as much
+ // precision as we can.
+ TempVar highBits = fetchOutParam(curr->value);
+ TempVar lowBits = getTemp();
+ TempVar highResult = getTemp();
+
+ UnaryOp convertHigh;
+ switch (curr->op) {
+ case ConvertSInt64ToFloat32:
+ case ConvertSInt64ToFloat64:
+ convertHigh = ConvertSInt32ToFloat64;
+ break;
+ case ConvertUInt64ToFloat32:
+ case ConvertUInt64ToFloat64:
+ convertHigh = ConvertUInt32ToFloat64;
+ break;
+ default: abort();
+ }
+
+ Expression *result = builder->blockify(
+ builder->makeSetLocal(lowBits, curr->value),
+ builder->makeSetLocal(
+ highResult,
+ builder->makeConst(Literal(int32_t(0)))
+ ),
+ builder->makeBinary(
+ AddFloat64,
+ builder->makeUnary(
+ ConvertUInt32ToFloat64,
+ builder->makeGetLocal(lowBits, i32)
+ ),
+ builder->makeBinary(
+ MulFloat64,
+ builder->makeConst(Literal((double)UINT_MAX + 1)),
+ builder->makeUnary(
+ convertHigh,
+ builder->makeGetLocal(highBits, i32)
+ )
+ )
+ )
);
+
+ switch (curr->op) {
+ case ConvertSInt64ToFloat32:
+ case ConvertUInt64ToFloat32: {
+ result = builder->makeUnary(DemoteFloat64, result);
+ break;
+ }
+ default: break;
+ }
+
replaceCurrent(result);
}
@@ -889,11 +1065,11 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
case TruncSFloat32ToInt64:
case TruncUFloat32ToInt64:
case TruncSFloat64ToInt64:
- case TruncUFloat64ToInt64:
+ case TruncUFloat64ToInt64: lowerTruncFloatToInt(curr); break;
case ConvertSInt64ToFloat32:
case ConvertSInt64ToFloat64:
case ConvertUInt64ToFloat32:
- case ConvertUInt64ToFloat64:
+ case ConvertUInt64ToFloat64: lowerConvertIntToFloat(curr); break;
default:
std::cerr << "Unhandled unary operator: " << curr->op << std::endl;
abort();
@@ -1386,7 +1562,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
),
builder->makeSetLocal(
rightHigh,
- builder->makeGetGlobal(highBitsGlobal, i32)
+ builder->makeGetGlobal(INT64_TO_32_HIGH_BITS, i32)
),
builder->makeGetLocal(lowResult, i32)
);
@@ -1680,7 +1856,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
TempVar highBits = fetchOutParam(curr->value);
SetLocal* setLow = builder->makeSetLocal(lowBits, curr->value);
SetGlobal* setHigh = builder->makeSetGlobal(
- highBitsGlobal,
+ INT64_TO_32_HIGH_BITS,
builder->makeGetLocal(highBits, i32)
);
curr->value = builder->makeGetLocal(lowBits, i32);
@@ -1693,20 +1869,24 @@ private:
std::unordered_map<Index, Index> indexMap;
std::unordered_map<Expression*, TempVar> highBitVars;
std::unordered_map<Name, TempVar> labelHighBitVars;
- std::vector<Index> freeTemps;
+ std::unordered_map<int, std::vector<Index>> freeTemps;
+ std::unordered_map<Index, Type> tempTypes;
Index nextTemp;
bool needRotl64 = false;
bool needRotr64 = false;
- TempVar getTemp() {
+ TempVar getTemp(Type ty = i32) {
Index ret;
- if (freeTemps.size() > 0) {
- ret = freeTemps.back();
- freeTemps.pop_back();
+ auto &freeList = freeTemps[(int) ty];
+ if (freeList.size() > 0) {
+ ret = freeList.back();
+ freeList.pop_back();
} else {
ret = nextTemp++;
+ tempTypes[ret] = ty;
}
- return TempVar(ret, *this);
+ assert(tempTypes[ret] == ty);
+ return TempVar(ret, ty, *this);
}
bool hasOutParam(Expression* e) {
@@ -1726,8 +1906,6 @@ private:
}
};
-Name I64ToI32Lowering::highBitsGlobal("i64toi32_i32$HIGH_BITS");
-
Pass *createI64ToI32LoweringPass() {
return new I64ToI32Lowering();
}