diff options
Diffstat (limited to 'src/passes/I64ToI32Lowering.cpp')
-rw-r--r-- | src/passes/I64ToI32Lowering.cpp | 252 |
1 files changed, 215 insertions, 37 deletions
diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp index e6da9a1b8..2f6fdf122 100644 --- a/src/passes/I64ToI32Lowering.cpp +++ b/src/passes/I64ToI32Lowering.cpp @@ -40,10 +40,10 @@ static Name makeHighName(Name n) { struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { struct TempVar { - TempVar(Index idx, I64ToI32Lowering& pass) : - idx(idx), pass(pass), moved(false) {} + TempVar(Index idx, Type ty, I64ToI32Lowering& pass) : + idx(idx), pass(pass), moved(false), ty(ty) {} - TempVar(TempVar&& other) : idx(other), pass(other.pass), moved(false) { + TempVar(TempVar&& other) : idx(other), pass(other.pass), moved(false), ty(other.ty) { assert(!other.moved); other.moved = true; } @@ -78,18 +78,17 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { private: void freeIdx() { - assert(std::find(pass.freeTemps.begin(), pass.freeTemps.end(), idx) == - pass.freeTemps.end()); - pass.freeTemps.push_back(idx); + auto &freeList = pass.freeTemps[(int) ty]; + assert(std::find(freeList.begin(), freeList.end(), idx) == freeList.end()); + freeList.push_back(idx); } Index idx; I64ToI32Lowering& pass; bool moved; // since C++ will still destruct moved-from values + Type ty; }; - static Name highBitsGlobal; - // false since function types need to be lowered // TODO: allow module-level transformations in parallel passes bool isFunctionParallel() override { return false; } @@ -114,7 +113,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { // to return the high 32 bits. auto* highBits = new Global(); highBits->type = i32; - highBits->name = highBitsGlobal; + highBits->name = INT64_TO_32_HIGH_BITS; highBits->init = builder->makeConst(Literal(int32_t(0))); highBits->mutable_ = true; module->addGlobal(highBits); @@ -159,7 +158,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { BinaryOp secondOp = leftShift ? ShrUInt32 : ShlInt32; Block* equalRotateBlock = builder->blockify( builder->makeSetGlobal( - highBitsGlobal, + INT64_TO_32_HIGH_BITS, builder->makeGetLocal(lowBits, i32) ), builder->makeGetLocal(highBits, i32) @@ -182,7 +181,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { ) ), builder->makeSetGlobal( - highBitsGlobal, + INT64_TO_32_HIGH_BITS, builder->makeBinary( OrInt32, builder->makeBinary( @@ -221,7 +220,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { ) ), builder->makeSetGlobal( - highBitsGlobal, + INT64_TO_32_HIGH_BITS, builder->makeBinary( OrInt32, builder->makeBinary( @@ -328,18 +327,17 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { func->body ); SetGlobal* setHigh = builder->makeSetGlobal( - highBitsGlobal, + INT64_TO_32_HIGH_BITS, builder->makeGetLocal(highBits, i32) ); GetLocal* getLow = builder->makeGetLocal(lowBits, i32); func->body = builder->blockify(setLow, setHigh, getLow); } } - assert(freeTemps.size() == nextTemp - func->getNumLocals()); int idx = 0; for (size_t i = func->getNumLocals(); i < nextTemp; i++) { Name tmpName("i64toi32_i32$" + std::to_string(idx++)); - builder->addVar(func, tmpName, i32); + builder->addVar(func, tmpName, tempTypes[i]); } } @@ -488,7 +486,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { ); SetLocal* setHigh = builder->makeSetLocal( highBits, - builder->makeGetGlobal(highBitsGlobal, i32) + builder->makeGetGlobal(INT64_TO_32_HIGH_BITS, i32) ); GetLocal* getLow = builder->makeGetLocal(lowBits, i32); Block* result = builder->blockify(doCall, setHigh, getLow); @@ -733,15 +731,14 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { void lowerReinterpretFloat64(Unary* curr) { // Assume that the wasm file assumes the address 0 is invalid and roundtrip // our f64 through memory at address 0 - Expression* zero = builder->makeConst(Literal(int32_t(0))); TempVar highBits = getTemp(); Block *result = builder->blockify( - builder->makeStore(8, 0, 8, zero, curr->value, f64), + builder->makeStore(8, 0, 8, builder->makeConst(Literal(int32_t(0))), curr->value, f64), builder->makeSetLocal( highBits, - builder->makeLoad(4, true, 4, 4, zero, i32) + builder->makeLoad(4, true, 4, 4, builder->makeConst(Literal(int32_t(0))), i32) ), - builder->makeLoad(4, true, 0, 4, zero, i32) + builder->makeLoad(4, true, 0, 4, builder->makeConst(Literal(int32_t(0))), i32) ); setOutParam(result, std::move(highBits)); replaceCurrent(result); @@ -751,13 +748,192 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { // Assume that the wasm file assumes the address 0 is invalid and roundtrip // our i64 through memory at address 0 TempVar highBits = fetchOutParam(curr->value); - TempVar lowBits = getTemp(); - Expression* zero = builder->makeConst(Literal(int32_t(0))); Block *result = builder->blockify( - builder->makeStore(4, 0, 4, zero, curr->value, i32), - builder->makeStore(4, 4, 4, zero, builder->makeGetLocal(highBits, i32), i32), - builder->makeLoad(8, true, 0, 8, zero, f64) + builder->makeStore(4, 0, 4, builder->makeConst(Literal(int32_t(0))), curr->value, i32), + builder->makeStore(4, 4, 4, builder->makeConst(Literal(int32_t(0))), builder->makeGetLocal(highBits, i32), i32), + builder->makeLoad(8, true, 0, 8, builder->makeConst(Literal(int32_t(0))), f64) + ); + replaceCurrent(result); + } + + void lowerTruncFloatToInt(Unary *curr) { + // hiBits = if abs(f) >= 1.0 { + // if f > 0.0 { + // (unsigned) min( + // floor(f / (float) U32_MAX), + // (float) U32_MAX - 1, + // ) + // } else { + // (unsigned) ceil((f - (float) (unsigned) f) / ((float) U32_MAX)) + // } + // } else { + // 0 + // } + // + // loBits = (unsigned) f; + + Literal litZero, litOne, u32Max; + UnaryOp trunc, convert, abs, floor, ceil; + Type localType; + BinaryOp ge, gt, min, div, sub; + switch (curr->op) { + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt64: { + litZero = Literal((float) 0); + litOne = Literal((float) 1); + u32Max = Literal(((float) UINT_MAX) + 1); + trunc = TruncUFloat32ToInt32; + convert = ConvertUInt32ToFloat32; + localType = f32; + abs = AbsFloat32; + ge = GeFloat32; + gt = GtFloat32; + min = MinFloat32; + floor = FloorFloat32; + ceil = CeilFloat32; + div = DivFloat32; + sub = SubFloat32; + break; + } + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt64: { + litZero = Literal((double) 0); + litOne = Literal((double) 1); + u32Max = Literal(((double) UINT_MAX) + 1); + trunc = TruncUFloat64ToInt32; + convert = ConvertUInt32ToFloat64; + localType = f64; + abs = AbsFloat64; + ge = GeFloat64; + gt = GtFloat64; + min = MinFloat64; + floor = FloorFloat64; + ceil = CeilFloat64; + div = DivFloat64; + sub = SubFloat64; + break; + } + default: abort(); + } + + TempVar f = getTemp(localType); + TempVar highBits = getTemp(); + + Expression *gtZeroBranch = builder->makeBinary( + min, + builder->makeUnary( + floor, + builder->makeBinary( + div, + builder->makeGetLocal(f, localType), + builder->makeConst(u32Max) + ) + ), + builder->makeBinary(sub, builder->makeConst(u32Max), builder->makeConst(litOne)) + ); + Expression *ltZeroBranch = builder->makeUnary( + ceil, + builder->makeBinary( + div, + builder->makeBinary( + sub, + builder->makeGetLocal(f, localType), + builder->makeUnary(convert, + builder->makeUnary(trunc, builder->makeGetLocal(f, localType)) + ) + ), + builder->makeConst(u32Max) + ) + ); + + If *highBitsCalc = builder->makeIf( + builder->makeBinary( + gt, + builder-> makeGetLocal(f, localType), + builder->makeConst(litZero) + ), + builder->makeUnary(trunc, gtZeroBranch), + builder->makeUnary(trunc, ltZeroBranch) + ); + If *highBitsVal = builder->makeIf( + builder->makeBinary( + ge, + builder->makeUnary(abs, builder->makeGetLocal(f, localType)), + builder->makeConst(litOne) + ), + highBitsCalc, + builder->makeConst(Literal(int32_t(0))) + ); + Block *result = builder->blockify( + builder->makeSetLocal(f, curr->value), + builder->makeSetLocal(highBits, highBitsVal), + builder->makeUnary(trunc, builder->makeGetLocal(f, localType)) + ); + setOutParam(result, std::move(highBits)); + replaceCurrent(result); + } + + void lowerConvertIntToFloat(Unary *curr) { + // Here the same strategy as `emcc` is taken which takes the two halves of + // the 64-bit integer and creates a mathematical expression using float + // arithmetic to reassemble the final floating point value. + // + // For example for i64 -> f32 we generate: + // + // ((double) (unsigned) lowBits) + ((double) U32_MAX) * ((double) (int) highBits) + // + // Mostly just shuffling things around here with coercions and whatnot! + // Note though that all arithmetic is done with f64 to have as much + // precision as we can. + TempVar highBits = fetchOutParam(curr->value); + TempVar lowBits = getTemp(); + TempVar highResult = getTemp(); + + UnaryOp convertHigh; + switch (curr->op) { + case ConvertSInt64ToFloat32: + case ConvertSInt64ToFloat64: + convertHigh = ConvertSInt32ToFloat64; + break; + case ConvertUInt64ToFloat32: + case ConvertUInt64ToFloat64: + convertHigh = ConvertUInt32ToFloat64; + break; + default: abort(); + } + + Expression *result = builder->blockify( + builder->makeSetLocal(lowBits, curr->value), + builder->makeSetLocal( + highResult, + builder->makeConst(Literal(int32_t(0))) + ), + builder->makeBinary( + AddFloat64, + builder->makeUnary( + ConvertUInt32ToFloat64, + builder->makeGetLocal(lowBits, i32) + ), + builder->makeBinary( + MulFloat64, + builder->makeConst(Literal((double)UINT_MAX + 1)), + builder->makeUnary( + convertHigh, + builder->makeGetLocal(highBits, i32) + ) + ) + ) ); + + switch (curr->op) { + case ConvertSInt64ToFloat32: + case ConvertUInt64ToFloat32: { + result = builder->makeUnary(DemoteFloat64, result); + break; + } + default: break; + } + replaceCurrent(result); } @@ -889,11 +1065,11 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { case TruncSFloat32ToInt64: case TruncUFloat32ToInt64: case TruncSFloat64ToInt64: - case TruncUFloat64ToInt64: + case TruncUFloat64ToInt64: lowerTruncFloatToInt(curr); break; case ConvertSInt64ToFloat32: case ConvertSInt64ToFloat64: case ConvertUInt64ToFloat32: - case ConvertUInt64ToFloat64: + case ConvertUInt64ToFloat64: lowerConvertIntToFloat(curr); break; default: std::cerr << "Unhandled unary operator: " << curr->op << std::endl; abort(); @@ -1386,7 +1562,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { ), builder->makeSetLocal( rightHigh, - builder->makeGetGlobal(highBitsGlobal, i32) + builder->makeGetGlobal(INT64_TO_32_HIGH_BITS, i32) ), builder->makeGetLocal(lowResult, i32) ); @@ -1680,7 +1856,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { TempVar highBits = fetchOutParam(curr->value); SetLocal* setLow = builder->makeSetLocal(lowBits, curr->value); SetGlobal* setHigh = builder->makeSetGlobal( - highBitsGlobal, + INT64_TO_32_HIGH_BITS, builder->makeGetLocal(highBits, i32) ); curr->value = builder->makeGetLocal(lowBits, i32); @@ -1693,20 +1869,24 @@ private: std::unordered_map<Index, Index> indexMap; std::unordered_map<Expression*, TempVar> highBitVars; std::unordered_map<Name, TempVar> labelHighBitVars; - std::vector<Index> freeTemps; + std::unordered_map<int, std::vector<Index>> freeTemps; + std::unordered_map<Index, Type> tempTypes; Index nextTemp; bool needRotl64 = false; bool needRotr64 = false; - TempVar getTemp() { + TempVar getTemp(Type ty = i32) { Index ret; - if (freeTemps.size() > 0) { - ret = freeTemps.back(); - freeTemps.pop_back(); + auto &freeList = freeTemps[(int) ty]; + if (freeList.size() > 0) { + ret = freeList.back(); + freeList.pop_back(); } else { ret = nextTemp++; + tempTypes[ret] = ty; } - return TempVar(ret, *this); + assert(tempTypes[ret] == ty); + return TempVar(ret, ty, *this); } bool hasOutParam(Expression* e) { @@ -1726,8 +1906,6 @@ private: } }; -Name I64ToI32Lowering::highBitsGlobal("i64toi32_i32$HIGH_BITS"); - Pass *createI64ToI32LoweringPass() { return new I64ToI32Lowering(); } |