diff options
author | Thomas Lively <7121787+tlively@users.noreply.github.com> | 2017-09-01 14:26:01 -0400 |
---|---|---|
committer | Alon Zakai <alonzakai@gmail.com> | 2017-09-01 11:26:01 -0700 |
commit | b013f744e3d70effd9be348cbde7fb93f0a16c6a (patch) | |
tree | 3b122293005d3370c931175eed92ad61e9dfd851 /src/passes/I64ToI32Lowering.cpp | |
parent | b1e8b1b515b2a1d0264975abc4de39c8044f7195 (diff) | |
download | binaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.tar.gz binaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.tar.bz2 binaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.zip |
i64 to i32 lowering for wasm2asm (#1134)
Diffstat (limited to 'src/passes/I64ToI32Lowering.cpp')
-rw-r--r-- | src/passes/I64ToI32Lowering.cpp | 1197 |
1 files changed, 1197 insertions, 0 deletions
diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp new file mode 100644 index 000000000..65871ca6f --- /dev/null +++ b/src/passes/I64ToI32Lowering.cpp @@ -0,0 +1,1197 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Lowers i64s to i32s by splitting variables and arguments +// into pairs of i32s. i64 return values are lowered by +// returning the low half and storing the high half into a +// global. +// + +#include <algorithm> +#include "wasm.h" +#include "pass.h" +#include "emscripten-optimizer/istring.h" +#include "support/name.h" +#include "wasm-builder.h" + + +namespace wasm { + +static Name makeHighName(Name n) { + return Name( + cashew::IString((std::string(n.c_str()) + "$hi").c_str(), false) + ); +} + +struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { + static Name highBitsGlobal; + + // false since function types need to be lowered + // TODO: allow module-level transformations in parallel passes + bool isFunctionParallel() override { return false; } + + Pass* create() override { + return new I64ToI32Lowering; + } + + void doWalkModule(Module* module) { + if (!builder) builder = make_unique<Builder>(*module); + // add new globals for high bits + for (size_t i = 0, globals = module->globals.size(); i < globals; ++i) { + auto& curr = module->globals[i]; + if (curr->type != i64) continue; + curr->type = i32; + auto* high = new Global(*curr); + high->name = makeHighName(curr->name); + module->addGlobal(high); + } + PostWalker<I64ToI32Lowering>::doWalkModule(module); + } + + void visitFunctionType(FunctionType* curr) { + std::vector<WasmType> params; + for (auto t : curr->params) { + if (t == i64) { + params.push_back(i32); + params.push_back(i32); + } else { + params.push_back(t); + } + } + std::swap(params, curr->params); + if (curr->result == i64) { + curr->result = i32; + } + } + + void doWalkFunction(Function* func) { + // create builder here if this is first entry to module for this object + if (!builder) builder = make_unique<Builder>(*getModule()); + indexMap.clear(); + returnIndices.clear(); + labelIndices.clear(); + freeTemps.clear(); + Function oldFunc(*func); + func->params.clear(); + func->vars.clear(); + func->localNames.clear(); + func->localIndices.clear(); + Index newIdx = 0; + for (Index i = 0; i < oldFunc.getNumLocals(); ++i) { + assert(oldFunc.hasLocalName(i)); + Name lowName = oldFunc.getLocalName(i); + Name highName = makeHighName(lowName); + WasmType paramType = oldFunc.getLocalType(i); + auto builderFunc = (i < oldFunc.getVarIndexBase()) ? + Builder::addParam : + static_cast<Index (*)(Function*, Name, WasmType)>(Builder::addVar); + if (paramType == i64) { + builderFunc(func, lowName, i32); + builderFunc(func, highName, i32); + indexMap[i] = newIdx; + newIdx += 2; + } else { + builderFunc(func, lowName, paramType); + indexMap[i] = newIdx++; + } + } + nextTemp = func->getNumLocals(); + PostWalker<I64ToI32Lowering>::doWalkFunction(func); + } + + void visitFunction(Function* func) { + if (func->result == i64) { + func->result = i32; + // body may not have out param if it ends with control flow + if (hasOutParam(func->body)) { + Index highBits = fetchOutParam(func->body); + Index lowBits = getTemp(); + SetLocal* setLow = builder->makeSetLocal( + lowBits, + func->body + ); + SetGlobal* setHigh = builder->makeSetGlobal( + highBitsGlobal, + builder->makeGetLocal(highBits, i32) + ); + GetLocal* getLow = builder->makeGetLocal(lowBits, i32); + func->body = builder->blockify(setLow, setHigh, getLow); + freeTemp(highBits); + freeTemp(lowBits); + } + } + assert(freeTemps.size() == nextTemp - func->getNumLocals()); + int idx = 0; + for (size_t i = func->getNumLocals(); i < nextTemp; i++) { + Name tmpName("i64toi32_i32$" + std::to_string(idx++)); + builder->addVar(func, tmpName, i32); + } + } + + void visitBlock(Block* curr) { + if (curr->list.size() == 0) return; + if (curr->type == i64) curr->type = i32; + auto highBitsIt = labelIndices.find(curr->name); + if (!hasOutParam(curr->list.back())) { + if (highBitsIt != labelIndices.end()) { + setOutParam(curr, highBitsIt->second); + } + return; + } + Index lastHighBits = fetchOutParam(curr->list.back()); + if (highBitsIt == labelIndices.end() || + highBitsIt->second == lastHighBits) { + setOutParam(curr, lastHighBits); + if (highBitsIt != labelIndices.end()) { + labelIndices.erase(highBitsIt); + } + return; + } + Index highBits = highBitsIt->second; + Index tmp = getTemp(); + labelIndices.erase(highBitsIt); + SetLocal* setLow = builder->makeSetLocal(tmp, curr->list.back()); + SetLocal* setHigh = builder->makeSetLocal( + highBits, + builder->makeGetLocal(lastHighBits, i32) + ); + GetLocal* getLow = builder->makeGetLocal(tmp, i32); + curr->list.back() = builder->blockify(setLow, setHigh, getLow); + setOutParam(curr, highBits); + freeTemp(lastHighBits); + freeTemp(tmp); + } + + // If and Select have identical code + template <typename T> + void visitBranching(T* curr) { + if (!hasOutParam(curr->ifTrue)) return; + assert(curr->ifFalse != nullptr && "Nullable ifFalse found"); + Index highBits = fetchOutParam(curr->ifTrue); + Index falseBits = fetchOutParam(curr->ifFalse); + Index tmp = getTemp(); + curr->type = i32; + curr->ifFalse = builder->blockify( + builder->makeSetLocal(tmp, curr->ifFalse), + builder->makeSetLocal( + highBits, + builder->makeGetLocal(falseBits, i32) + ), + builder->makeGetLocal(tmp, i32) + ); + freeTemp(tmp); + freeTemp(falseBits); + setOutParam(curr, highBits); + } + + void visitIf(If* curr) { + visitBranching<If>(curr); + } + + void visitLoop(Loop* curr) { + assert(labelIndices.find(curr->name) == labelIndices.end()); + if (curr->type != i64) return; + curr->type = i32; + setOutParam(curr, fetchOutParam(curr->body)); + } + + void visitBreak(Break* curr) { + if (!hasOutParam(curr->value)) return; + assert(curr->value != nullptr); + Index valHighBits = fetchOutParam(curr->value); + auto blockHighBitsIt = labelIndices.find(curr->name); + if (blockHighBitsIt == labelIndices.end()) { + labelIndices[curr->name] = valHighBits; + curr->type = i32; + return; + } + Index blockHighBits = blockHighBitsIt->second; + Index tmp = getTemp(); + SetLocal* setLow = builder->makeSetLocal(tmp, curr->value); + SetLocal* setHigh = builder->makeSetLocal( + blockHighBits, + builder->makeGetLocal(valHighBits, i32) + ); + curr->value = builder->makeGetLocal(tmp, i32); + curr->type = i32; + replaceCurrent(builder->blockify(setLow, setHigh, curr)); + freeTemp(tmp); + } + + void visitSwitch(Switch* curr) { + if (!hasOutParam(curr->value)) return; + Index outParam = fetchOutParam(curr->value); + Index tmp = getTemp(); + bool didReuseOutParam = false; + Expression* result = curr; + std::vector<Name> targets; + auto processTarget = [&](Name target) -> Name { + auto labelIt = labelIndices.find(target); + if (labelIt == labelIndices.end() || labelIt->second == outParam) { + labelIndices[target] = outParam; + didReuseOutParam = true; + return target; + } + Index labelOutParam = labelIt->second; + Name newLabel("$i64toi32_" + std::string(target.c_str())); + result = builder->blockify( + builder->makeSetLocal(tmp, builder->makeBlock(newLabel, result)), + builder->makeSetLocal( + labelOutParam, + builder->makeGetLocal(outParam, i32) + ), + builder->makeGetLocal(tmp, i32) + ); + assert(result->type == i32); + return newLabel; + }; + for (auto target : curr->targets) { + targets.push_back(processTarget(target)); + } + curr->default_ = processTarget(curr->default_); + replaceCurrent(result); + freeTemp(tmp); + if (!didReuseOutParam) { + freeTemp(outParam); + } + } + + template <typename T> + using BuilderFunc = std::function<T*(std::vector<Expression*>&, WasmType)>; + + template <typename T> + void visitGenericCall(T* curr, BuilderFunc<T> callBuilder) { + std::vector<Expression*> args; + for (auto* e : curr->operands) { + args.push_back(e); + if (hasOutParam(e)) { + Index argHighBits = fetchOutParam(e); + args.push_back(builder->makeGetLocal(argHighBits, i32)); + freeTemp(argHighBits); + } + } + if (curr->type != i64) { + replaceCurrent(callBuilder(args, curr->type)); + return; + } + Index lowBits = getTemp(); + Index highBits = getTemp(); + SetLocal* doCall = builder->makeSetLocal( + lowBits, + callBuilder(args, i32) + ); + SetLocal* setHigh = builder->makeSetLocal( + highBits, + builder->makeGetGlobal(highBitsGlobal, i32) + ); + GetLocal* getLow = builder->makeGetLocal(lowBits, i32); + Block* result = builder->blockify(doCall, setHigh, getLow); + freeTemp(lowBits); + setOutParam(result, highBits); + replaceCurrent(result); + } + void visitCall(Call* curr) { + visitGenericCall<Call>( + curr, + [&](std::vector<Expression*>& args, WasmType ty) { + return builder->makeCall(curr->target, args, ty); + } + ); + } + + void visitCallImport(CallImport* curr) { + // imports cannot contain i64s + return; + } + + void visitCallIndirect(CallIndirect* curr) { + visitGenericCall<CallIndirect>( + curr, + [&](std::vector<Expression*>& args, WasmType ty) { + return builder->makeCallIndirect( + curr->fullType, + curr->target, + args, + ty + ); + } + ); + } + + void visitGetLocal(GetLocal* curr) { + if (curr->type != i64) return; + curr->index = indexMap[curr->index]; + curr->type = i32; + Index highBits = getTemp(); + SetLocal *setHighBits = builder->makeSetLocal( + highBits, + builder->makeGetLocal( + curr->index + 1, + i32 + ) + ); + Block* result = builder->blockify(setHighBits, curr); + replaceCurrent(result); + setOutParam(result, highBits); + } + + void lowerTee(SetLocal* curr) { + Index highBits = fetchOutParam(curr->value); + Index tmp = getTemp(); + curr->index = indexMap[curr->index]; + curr->type = i32; + SetLocal* setLow = builder->makeSetLocal(tmp, curr); + SetLocal* setHigh = builder->makeSetLocal( + curr->index + 1, + builder->makeGetLocal(highBits, i32) + ); + GetLocal* getLow = builder->makeGetLocal(tmp, i32); + Block* result = builder->blockify(setLow, setHigh, getLow); + replaceCurrent(result); + setOutParam(result, highBits); + freeTemp(tmp); + } + + void visitSetLocal(SetLocal* curr) { + if (!hasOutParam(curr->value)) return; + if (curr->isTee()) { + lowerTee(curr); + return; + } + Index highBits = fetchOutParam(curr->value); + curr->index = indexMap[curr->index]; + SetLocal* setHigh = builder->makeSetLocal( + curr->index + 1, + builder->makeGetLocal(highBits, i32) + ); + Block* result = builder->blockify(curr, setHigh); + replaceCurrent(result); + freeTemp(highBits); + } + + void visitGetGlobal(GetGlobal* curr) { + assert(false && "GetGlobal not implemented"); + } + + void visitSetGlobal(SetGlobal* curr) { + assert(false && "SetGlobal not implemented"); + } + + void visitLoad(Load* curr) { + if (curr->type != i64) return; + assert(!curr->isAtomic && "atomic load not implemented"); + Index highBits = getTemp(); + Index ptrTemp = getTemp(); + SetLocal* setPtr = builder->makeSetLocal(ptrTemp, curr->ptr); + SetLocal* loadHigh; + if (curr->bytes == 8) { + loadHigh = builder->makeSetLocal( + highBits, + builder->makeLoad( + 4, + curr->signed_, + curr->offset + 4, + 1, + builder->makeGetLocal(ptrTemp, i32), + i32 + ) + ); + } else { + loadHigh = builder->makeSetLocal( + highBits, + builder->makeConst(Literal(int32_t(0))) + ); + } + curr->type = i32; + curr->bytes = std::min(curr->bytes, uint8_t(4)); + curr->align = std::min(uint32_t(curr->align), uint32_t(4)); + curr->ptr = builder->makeGetLocal(ptrTemp, i32); + Block* result = builder->blockify(setPtr, loadHigh, curr); + replaceCurrent(result); + setOutParam(result, highBits); + freeTemp(ptrTemp); + } + + void visitStore(Store* curr) { + if (!hasOutParam(curr->value)) return; + assert(curr->offset + 4 > curr->offset); + assert(!curr->isAtomic && "atomic store not implemented"); + Index highBits = fetchOutParam(curr->value); + uint8_t bytes = curr->bytes; + curr->bytes = std::min(curr->bytes, uint8_t(4)); + curr->align = std::min(uint32_t(curr->align), uint32_t(4)); + curr->valueType = i32; + if (bytes == 8) { + Index ptrTemp = getTemp(); + SetLocal* setPtr = builder->makeSetLocal(ptrTemp, curr->ptr); + curr->ptr = builder->makeGetLocal(ptrTemp, i32); + Store* storeHigh = builder->makeStore( + 4, + curr->offset + 4, + 1, + builder->makeGetLocal(ptrTemp, i32), + builder->makeGetLocal(highBits, i32), + i32 + ); + replaceCurrent(builder->blockify(setPtr, curr, storeHigh)); + freeTemp(ptrTemp); + } + freeTemp(highBits); + } + + void visitAtomicRMW(AtomicRMW* curr) { + assert(false && "AtomicRMW not implemented"); + } + + void visitAtomicCmpxchg(AtomicCmpxchg* curr) { + assert(false && "AtomicCmpxchg not implemented"); + } + + void visitConst(Const* curr) { + if (curr->type != i64) return; + Index highBits = getTemp(); + Const* lowVal = builder->makeConst( + Literal(int32_t(curr->value.geti64() & 0xffffffff)) + ); + SetLocal* setHigh = builder->makeSetLocal( + highBits, + builder->makeConst( + Literal(int32_t(uint64_t(curr->value.geti64()) >> 32)) + ) + ); + Block* result = builder->blockify(setHigh, lowVal); + setOutParam(result, highBits); + replaceCurrent(result); + } + + void lowerEqZInt64(Unary* curr) { + Index highBits = fetchOutParam(curr->value); + replaceCurrent( + builder->makeBinary( + AndInt32, + builder->makeUnary(EqZInt32, builder->makeGetLocal(highBits, i32)), + builder->makeUnary(EqZInt32, curr->value) + ) + ); + freeTemp(highBits); + } + + void lowerExtendUInt32(Unary* curr) { + Index highBits = getTemp(); + Block* result = builder->blockify( + builder->makeSetLocal(highBits, builder->makeConst(Literal(int32_t(0)))), + curr->value + ); + setOutParam(result, highBits); + replaceCurrent(result); + } + + void lowerWrapInt64(Unary* curr) { + freeTemp(fetchOutParam(curr->value)); + replaceCurrent(curr->value); + } + + bool unaryNeedsLowering(UnaryOp op) { + switch (op) { + case ClzInt64: + case CtzInt64: + case PopcntInt64: + case EqZInt64: + case ExtendSInt32: + case ExtendUInt32: + case WrapInt64: + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt64: + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt64: + case ReinterpretFloat64: + case ConvertSInt64ToFloat32: + case ConvertSInt64ToFloat64: + case ConvertUInt64ToFloat32: + case ConvertUInt64ToFloat64: + case ReinterpretInt64: return true; + default: return false; + } + } + + void visitUnary(Unary* curr) { + if (!unaryNeedsLowering(curr->op)) return; + if (curr->type == unreachable || curr->value->type == unreachable) { + assert(!hasOutParam(curr->value)); + replaceCurrent(curr->value); + return; + } + assert(hasOutParam(curr->value) || curr->type == i64); + switch (curr->op) { + case ClzInt64: + case CtzInt64: + case PopcntInt64: goto err; + case EqZInt64: lowerEqZInt64(curr); break; + case ExtendSInt32: goto err; + case ExtendUInt32: lowerExtendUInt32(curr); break; + case WrapInt64: lowerWrapInt64(curr); break; + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt64: + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt64: + case ReinterpretFloat64: + case ConvertSInt64ToFloat32: + case ConvertSInt64ToFloat64: + case ConvertUInt64ToFloat32: + case ConvertUInt64ToFloat64: + case ReinterpretInt64: + err: default: + std::cerr << "Unhandled unary operator: " << curr->op << std::endl; + abort(); + } + } + + Block* lowerAdd(Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + SetLocal* addLow = builder->makeSetLocal( + leftHigh, + builder->makeBinary( + AddInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ) + ); + SetLocal* addHigh = builder->makeSetLocal( + rightHigh, + builder->makeBinary( + AddInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ); + SetLocal* carryBit = builder->makeSetLocal( + rightHigh, + builder->makeBinary( + AddInt32, + builder->makeGetLocal(rightHigh, i32), + builder->makeConst(Literal(int32_t(1))) + ) + ); + If* checkOverflow = builder->makeIf( + builder->makeBinary( + LtUInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ), + carryBit + ); + GetLocal* getLow = builder->makeGetLocal(leftHigh, i32); + result = builder->blockify(result, addLow, addHigh, checkOverflow, getLow); + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + setOutParam(result, rightHigh); + return result; + } + + Block* lowerMul(Block* result, Index leftLow, Index leftHigh, Index rightLow, + Index rightHigh) { + // high bits = ll*rh + lh*rl + ll1*rl1 + (ll0*rl1)>>16 + (ll1*rl0)>>16 + // low bits = ll*rl + Index leftLow0 = getTemp(); + Index leftLow1 = getTemp(); + Index rightLow0 = getTemp(); + Index rightLow1 = getTemp(); + SetLocal* setLL0 = builder->makeSetLocal( + leftLow0, + builder->makeBinary( + AndInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeConst(Literal(int32_t(0xffff))) + ) + ); + SetLocal* setLL1 = builder->makeSetLocal( + leftLow1, + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeConst(Literal(int32_t(16))) + ) + ); + SetLocal* setRL0 = builder->makeSetLocal( + rightLow0, + builder->makeBinary( + AndInt32, + builder->makeGetLocal(rightLow, i32), + builder->makeConst(Literal(int32_t(0xffff))) + ) + ); + SetLocal* setRL1 = builder->makeSetLocal( + rightLow1, + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(rightLow, i32), + builder->makeConst(Literal(int32_t(16))) + ) + ); + SetLocal* setLLRH = builder->makeSetLocal( + rightHigh, + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ); + auto addToHighBits = [&](Expression* expr) -> SetLocal* { + return builder->makeSetLocal( + rightHigh, + builder->makeBinary( + AddInt32, + builder->makeGetLocal(rightHigh, i32), + expr + ) + ); + }; + SetLocal* addLHRL = addToHighBits( + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightLow, i32) + ) + ); + SetLocal* addLL1RL1 = addToHighBits( + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow1, i32), + builder->makeGetLocal(rightLow1, i32) + ) + ); + SetLocal* addLL0RL1 = addToHighBits( + builder->makeBinary( + ShrUInt32, + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow0, i32), + builder->makeGetLocal(rightLow1, i32) + ), + builder->makeConst(Literal(int32_t(16))) + ) + ); + SetLocal* addLL1RL0 = addToHighBits( + builder->makeBinary( + ShrUInt32, + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow1, i32), + builder->makeGetLocal(rightLow0, i32) + ), + builder->makeConst(Literal(int32_t(16))) + ) + ); + Binary* getLow = builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ); + result = builder->blockify( + result, + setLL0, + setLL1, + setRL0, + setRL1, + setLLRH, + addLHRL, + addLL1RL1, + addLL0RL1, + addLL1RL0, + getLow + ); + freeTemp(leftLow0); + freeTemp(leftLow1); + freeTemp(rightLow0); + freeTemp(rightLow1); + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + setOutParam(result, rightHigh); + return result; + } + + Block* lowerBitwise(BinaryOp op, Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + BinaryOp op32; + switch (op) { + case AndInt64: op32 = AndInt32; break; + case OrInt64: op32 = OrInt32; break; + case XorInt64: op32 = XorInt32; break; + default: abort(); + } + result = builder->blockify( + result, + builder->makeSetLocal( + rightHigh, + builder->makeBinary( + op32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ), + builder->makeBinary( + op32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ) + ); + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + setOutParam(result, rightHigh); + return result; + } + + Block* makeLargeShl(Index highBits, Index leftLow, Index shift) { + return builder->blockify( + builder->makeSetLocal( + highBits, + builder->makeBinary( + ShlInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(shift, i32) + ) + ), + builder->makeConst(Literal(int32_t(0))) + ); + } + + Block* makeLargeShrU(Index highBits, Index leftHigh, Index shift) { + return builder->blockify( + builder->makeSetLocal(highBits, builder->makeConst(Literal(int32_t(0)))), + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(shift, i32) + ) + ); + } + + Block* makeSmallShl(Index highBits, Index leftLow, Index leftHigh, + Index shift, Binary* shiftMask, Binary* widthLessShift) { + Binary* shiftedInBits = builder->makeBinary( + AndInt32, + shiftMask, + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftLow, i32), + widthLessShift + ) + ); + Binary* shiftHigh = builder->makeBinary( + ShlInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(shift, i32) + ); + return builder->blockify( + builder->makeSetLocal( + highBits, + builder->makeBinary(OrInt32, shiftedInBits, shiftHigh) + ), + builder->makeBinary( + ShlInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(shift, i32) + ) + ); + } + + Block* makeSmallShrU(Index highBits, Index leftLow, Index leftHigh, + Index shift, Binary* shiftMask, Binary* widthLessShift) { + Binary* shiftedInBits = builder->makeBinary( + ShlInt32, + builder->makeBinary( + AndInt32, + shiftMask, + builder->makeGetLocal(leftHigh, i32) + ), + widthLessShift + ); + Binary* shiftLow = builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(shift, i32) + ); + return builder->blockify( + builder->makeSetLocal( + highBits, + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(shift, i32) + ) + ), + builder->makeBinary(OrInt32, shiftedInBits, shiftLow) + ); + } + + Block* lowerShU(BinaryOp op, Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + assert(op == ShlInt64 || op == ShrUInt64); + // shift left lowered as: + // if 32 <= rightLow % 64: + // high = leftLow << k; low = 0 + // else: + // high = (((1 << k) - 1) & (leftLow >> (32 - k))) | (leftHigh << k); + // low = leftLow << k + // where k = shift % 32. shift right is similar. + Index shift = getTemp(); + SetLocal* setShift = builder->makeSetLocal( + shift, + builder->makeBinary( + AndInt32, + builder->makeGetLocal(rightLow, i32), + builder->makeConst(Literal(int32_t(32 - 1))) + ) + ); + Binary* isLargeShift = builder->makeBinary( + LeUInt32, + builder->makeConst(Literal(int32_t(32))), + builder->makeBinary( + AndInt32, + builder->makeGetLocal(rightLow, i32), + builder->makeConst(Literal(int32_t(64 - 1))) + ) + ); + Block* largeShiftBlock; + switch (op) { + case ShlInt64: + largeShiftBlock = makeLargeShl(rightHigh, leftLow, shift); break; + case ShrUInt64: + largeShiftBlock = makeLargeShrU(rightHigh, leftHigh, shift); break; + default: abort(); + } + Binary* shiftMask = builder->makeBinary( + SubInt32, + builder->makeBinary( + ShlInt32, + builder->makeConst(Literal(int32_t(1))), + builder->makeGetLocal(shift, i32) + ), + builder->makeConst(Literal(int32_t(1))) + ); + Binary* widthLessShift = builder->makeBinary( + SubInt32, + builder->makeConst(Literal(int32_t(32))), + builder->makeGetLocal(shift, i32) + ); + Block* smallShiftBlock; + switch(op) { + case ShlInt64: { + smallShiftBlock = makeSmallShl(rightHigh, leftLow, leftHigh, + shift, shiftMask, widthLessShift); + break; + } + case ShrUInt64: { + smallShiftBlock = makeSmallShrU(rightHigh, leftLow, leftHigh, + shift, shiftMask, widthLessShift); + break; + } + default: abort(); + } + If* ifLargeShift = builder->makeIf( + isLargeShift, + largeShiftBlock, + smallShiftBlock + ); + result = builder->blockify(result, setShift, ifLargeShift); + freeTemp(shift); + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + setOutParam(result, rightHigh); + return result; + } + + Block* lowerEq(Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + freeTemp(rightHigh); + return builder->blockify( + result, + builder->makeBinary( + AndInt32, + builder->makeBinary( + EqInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ), + builder->makeBinary( + EqInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ) + ); + } + + Block* lowerNe(Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + freeTemp(rightHigh); + return builder->blockify( + result, + builder->makeBinary( + OrInt32, + builder->makeBinary( + NeInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ), + builder->makeBinary( + NeInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ) + ); + } + + Block* lowerUComp(BinaryOp op, Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + freeTemp(rightHigh); + BinaryOp highOp, lowOp; + switch (op) { + case LtUInt64: highOp = LtUInt32; lowOp = LtUInt32; break; + case LeUInt64: highOp = LtUInt32; lowOp = LeUInt32; break; + case GtUInt64: highOp = GtUInt32; lowOp = GtUInt32; break; + case GeUInt64: highOp = GtUInt32; lowOp = GeUInt32; break; + default: abort(); + } + Binary* compHigh = builder->makeBinary( + highOp, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ); + Binary* eqHigh = builder->makeBinary( + EqInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ); + Binary* compLow = builder->makeBinary( + lowOp, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ); + return builder->blockify( + result, + builder->makeBinary( + OrInt32, + compHigh, + builder->makeBinary(AndInt32, eqHigh, compLow) + ) + ); + } + + bool binaryNeedsLowering(BinaryOp op) { + switch (op) { + case AddInt64: + case SubInt64: + case MulInt64: + case DivSInt64: + case DivUInt64: + case RemSInt64: + case RemUInt64: + case AndInt64: + case OrInt64: + case XorInt64: + case ShlInt64: + case ShrUInt64: + case ShrSInt64: + case RotLInt64: + case RotRInt64: + case EqInt64: + case NeInt64: + case LtSInt64: + case LtUInt64: + case LeSInt64: + case LeUInt64: + case GtSInt64: + case GtUInt64: + case GeSInt64: + case GeUInt64: return true; + default: return false; + } + } + + void visitBinary(Binary* curr) { + if (!binaryNeedsLowering(curr->op)) return; + if (!hasOutParam(curr->left)) { + // left unreachable, replace self with left + replaceCurrent(curr->left); + if (hasOutParam(curr->right)) { + freeTemp(fetchOutParam(curr->right)); + } + return; + } + if (!hasOutParam(curr->right)) { + // right unreachable, replace self with left then right + replaceCurrent( + builder->blockify(builder->makeDrop(curr->left), curr->right) + ); + freeTemp(fetchOutParam(curr->left)); + return; + } + // left and right reachable, lower normally + Index leftLow = getTemp(); + Index leftHigh = fetchOutParam(curr->left); + Index rightLow = getTemp(); + Index rightHigh = fetchOutParam(curr->right); + SetLocal* setRight = builder->makeSetLocal(rightLow, curr->right); + SetLocal* setLeft = builder->makeSetLocal(leftLow, curr->left); + Block* result = builder->blockify(setLeft, setRight); + switch (curr->op) { + case AddInt64: { + replaceCurrent(lowerAdd(result, leftLow, leftHigh, rightLow, + rightHigh)); + break; + } + case SubInt64: goto err; + case MulInt64: { + replaceCurrent(lowerMul(result, leftLow, leftHigh, rightLow, + rightHigh)); + break; + } + case DivSInt64: + case DivUInt64: + case RemSInt64: + case RemUInt64: goto err; + case AndInt64: + case OrInt64: + case XorInt64: { + replaceCurrent(lowerBitwise(curr->op, result, leftLow, leftHigh, + rightLow, rightHigh)); + break; + } + case ShlInt64: + case ShrUInt64: { + replaceCurrent(lowerShU(curr->op, result, leftLow, leftHigh, rightLow, + rightHigh)); + break; + } + case ShrSInt64: + case RotLInt64: + case RotRInt64: goto err; + case EqInt64: { + replaceCurrent(lowerEq(result, leftLow, leftHigh, rightLow, rightHigh)); + break; + } + case NeInt64: { + replaceCurrent(lowerNe(result, leftLow, leftHigh, rightLow, rightHigh)); + break; + } + case LtSInt64: + case LeSInt64: + case GtSInt64: + case GeSInt64: goto err; + case LtUInt64: + case LeUInt64: + case GtUInt64: + case GeUInt64: { + replaceCurrent(lowerUComp(curr->op, result, leftLow, leftHigh, rightLow, + rightHigh)); + break; + } + err: default: { + std::cerr << "Unhandled binary op " << curr->op << std::endl; + abort(); + } + } + } + + void visitSelect(Select* curr) { + visitBranching<Select>(curr); + } + + void visitDrop(Drop* curr) { + if (!hasOutParam(curr->value)) return; + freeTemp(fetchOutParam(curr->value)); + } + + void visitReturn(Return* curr) { + if (!hasOutParam(curr->value)) return; + Index lowBits = getTemp(); + Index highBits = fetchOutParam(curr->value); + SetLocal* setLow = builder->makeSetLocal(lowBits, curr->value); + SetGlobal* setHigh = builder->makeSetGlobal( + highBitsGlobal, + builder->makeGetLocal(highBits, i32) + ); + curr->value = builder->makeGetLocal(lowBits, i32); + Block* result = builder->blockify(setLow, setHigh, curr); + replaceCurrent(result); + freeTemp(lowBits); + freeTemp(highBits); + } + +private: + std::unique_ptr<Builder> builder; + std::unordered_map<Index, Index> indexMap; + std::unordered_map<Expression*, Index> returnIndices; + std::unordered_map<Name, Index> labelIndices; + std::vector<Index> freeTemps; + Index nextTemp; + + // TODO: RAII for temp var allocation + Index getTemp() { + Index ret; + if (freeTemps.size() > 0) { + ret = freeTemps.back(); + freeTemps.pop_back(); + } else { + ret = nextTemp++; + } + return ret; + } + + void freeTemp(Index t) { + assert(std::find(freeTemps.begin(), freeTemps.end(), t) == freeTemps.end()); + freeTemps.push_back(t); + } + + bool hasOutParam(Expression* e) { + return returnIndices.find(e) != returnIndices.end(); + } + + void setOutParam(Expression* e, Index idx) { + returnIndices[e] = idx; + } + + Index fetchOutParam(Expression* e) { + assert(returnIndices.find(e) != returnIndices.end()); + Index ret = returnIndices[e]; + returnIndices.erase(e); + return ret; + } +}; + +Name I64ToI32Lowering::highBitsGlobal("i64toi32_i32$HIGH_BITS"); + +Pass *createI64ToI32LoweringPass() { + return new I64ToI32Lowering(); +} + +} |