summaryrefslogtreecommitdiff
path: root/src/passes/I64ToI32Lowering.cpp
diff options
context:
space:
mode:
authorThomas Lively <7121787+tlively@users.noreply.github.com>2017-09-01 14:26:01 -0400
committerAlon Zakai <alonzakai@gmail.com>2017-09-01 11:26:01 -0700
commitb013f744e3d70effd9be348cbde7fb93f0a16c6a (patch)
tree3b122293005d3370c931175eed92ad61e9dfd851 /src/passes/I64ToI32Lowering.cpp
parentb1e8b1b515b2a1d0264975abc4de39c8044f7195 (diff)
downloadbinaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.tar.gz
binaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.tar.bz2
binaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.zip
i64 to i32 lowering for wasm2asm (#1134)
Diffstat (limited to 'src/passes/I64ToI32Lowering.cpp')
-rw-r--r--src/passes/I64ToI32Lowering.cpp1197
1 files changed, 1197 insertions, 0 deletions
diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp
new file mode 100644
index 000000000..65871ca6f
--- /dev/null
+++ b/src/passes/I64ToI32Lowering.cpp
@@ -0,0 +1,1197 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Lowers i64s to i32s by splitting variables and arguments
+// into pairs of i32s. i64 return values are lowered by
+// returning the low half and storing the high half into a
+// global.
+//
+
+#include <algorithm>
+#include "wasm.h"
+#include "pass.h"
+#include "emscripten-optimizer/istring.h"
+#include "support/name.h"
+#include "wasm-builder.h"
+
+
+namespace wasm {
+
+static Name makeHighName(Name n) {
+ return Name(
+ cashew::IString((std::string(n.c_str()) + "$hi").c_str(), false)
+ );
+}
+
+struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
+ static Name highBitsGlobal;
+
+ // false since function types need to be lowered
+ // TODO: allow module-level transformations in parallel passes
+ bool isFunctionParallel() override { return false; }
+
+ Pass* create() override {
+ return new I64ToI32Lowering;
+ }
+
+ void doWalkModule(Module* module) {
+ if (!builder) builder = make_unique<Builder>(*module);
+ // add new globals for high bits
+ for (size_t i = 0, globals = module->globals.size(); i < globals; ++i) {
+ auto& curr = module->globals[i];
+ if (curr->type != i64) continue;
+ curr->type = i32;
+ auto* high = new Global(*curr);
+ high->name = makeHighName(curr->name);
+ module->addGlobal(high);
+ }
+ PostWalker<I64ToI32Lowering>::doWalkModule(module);
+ }
+
+ void visitFunctionType(FunctionType* curr) {
+ std::vector<WasmType> params;
+ for (auto t : curr->params) {
+ if (t == i64) {
+ params.push_back(i32);
+ params.push_back(i32);
+ } else {
+ params.push_back(t);
+ }
+ }
+ std::swap(params, curr->params);
+ if (curr->result == i64) {
+ curr->result = i32;
+ }
+ }
+
+ void doWalkFunction(Function* func) {
+ // create builder here if this is first entry to module for this object
+ if (!builder) builder = make_unique<Builder>(*getModule());
+ indexMap.clear();
+ returnIndices.clear();
+ labelIndices.clear();
+ freeTemps.clear();
+ Function oldFunc(*func);
+ func->params.clear();
+ func->vars.clear();
+ func->localNames.clear();
+ func->localIndices.clear();
+ Index newIdx = 0;
+ for (Index i = 0; i < oldFunc.getNumLocals(); ++i) {
+ assert(oldFunc.hasLocalName(i));
+ Name lowName = oldFunc.getLocalName(i);
+ Name highName = makeHighName(lowName);
+ WasmType paramType = oldFunc.getLocalType(i);
+ auto builderFunc = (i < oldFunc.getVarIndexBase()) ?
+ Builder::addParam :
+ static_cast<Index (*)(Function*, Name, WasmType)>(Builder::addVar);
+ if (paramType == i64) {
+ builderFunc(func, lowName, i32);
+ builderFunc(func, highName, i32);
+ indexMap[i] = newIdx;
+ newIdx += 2;
+ } else {
+ builderFunc(func, lowName, paramType);
+ indexMap[i] = newIdx++;
+ }
+ }
+ nextTemp = func->getNumLocals();
+ PostWalker<I64ToI32Lowering>::doWalkFunction(func);
+ }
+
+ void visitFunction(Function* func) {
+ if (func->result == i64) {
+ func->result = i32;
+ // body may not have out param if it ends with control flow
+ if (hasOutParam(func->body)) {
+ Index highBits = fetchOutParam(func->body);
+ Index lowBits = getTemp();
+ SetLocal* setLow = builder->makeSetLocal(
+ lowBits,
+ func->body
+ );
+ SetGlobal* setHigh = builder->makeSetGlobal(
+ highBitsGlobal,
+ builder->makeGetLocal(highBits, i32)
+ );
+ GetLocal* getLow = builder->makeGetLocal(lowBits, i32);
+ func->body = builder->blockify(setLow, setHigh, getLow);
+ freeTemp(highBits);
+ freeTemp(lowBits);
+ }
+ }
+ assert(freeTemps.size() == nextTemp - func->getNumLocals());
+ int idx = 0;
+ for (size_t i = func->getNumLocals(); i < nextTemp; i++) {
+ Name tmpName("i64toi32_i32$" + std::to_string(idx++));
+ builder->addVar(func, tmpName, i32);
+ }
+ }
+
+ void visitBlock(Block* curr) {
+ if (curr->list.size() == 0) return;
+ if (curr->type == i64) curr->type = i32;
+ auto highBitsIt = labelIndices.find(curr->name);
+ if (!hasOutParam(curr->list.back())) {
+ if (highBitsIt != labelIndices.end()) {
+ setOutParam(curr, highBitsIt->second);
+ }
+ return;
+ }
+ Index lastHighBits = fetchOutParam(curr->list.back());
+ if (highBitsIt == labelIndices.end() ||
+ highBitsIt->second == lastHighBits) {
+ setOutParam(curr, lastHighBits);
+ if (highBitsIt != labelIndices.end()) {
+ labelIndices.erase(highBitsIt);
+ }
+ return;
+ }
+ Index highBits = highBitsIt->second;
+ Index tmp = getTemp();
+ labelIndices.erase(highBitsIt);
+ SetLocal* setLow = builder->makeSetLocal(tmp, curr->list.back());
+ SetLocal* setHigh = builder->makeSetLocal(
+ highBits,
+ builder->makeGetLocal(lastHighBits, i32)
+ );
+ GetLocal* getLow = builder->makeGetLocal(tmp, i32);
+ curr->list.back() = builder->blockify(setLow, setHigh, getLow);
+ setOutParam(curr, highBits);
+ freeTemp(lastHighBits);
+ freeTemp(tmp);
+ }
+
+ // If and Select have identical code
+ template <typename T>
+ void visitBranching(T* curr) {
+ if (!hasOutParam(curr->ifTrue)) return;
+ assert(curr->ifFalse != nullptr && "Nullable ifFalse found");
+ Index highBits = fetchOutParam(curr->ifTrue);
+ Index falseBits = fetchOutParam(curr->ifFalse);
+ Index tmp = getTemp();
+ curr->type = i32;
+ curr->ifFalse = builder->blockify(
+ builder->makeSetLocal(tmp, curr->ifFalse),
+ builder->makeSetLocal(
+ highBits,
+ builder->makeGetLocal(falseBits, i32)
+ ),
+ builder->makeGetLocal(tmp, i32)
+ );
+ freeTemp(tmp);
+ freeTemp(falseBits);
+ setOutParam(curr, highBits);
+ }
+
+ void visitIf(If* curr) {
+ visitBranching<If>(curr);
+ }
+
+ void visitLoop(Loop* curr) {
+ assert(labelIndices.find(curr->name) == labelIndices.end());
+ if (curr->type != i64) return;
+ curr->type = i32;
+ setOutParam(curr, fetchOutParam(curr->body));
+ }
+
+ void visitBreak(Break* curr) {
+ if (!hasOutParam(curr->value)) return;
+ assert(curr->value != nullptr);
+ Index valHighBits = fetchOutParam(curr->value);
+ auto blockHighBitsIt = labelIndices.find(curr->name);
+ if (blockHighBitsIt == labelIndices.end()) {
+ labelIndices[curr->name] = valHighBits;
+ curr->type = i32;
+ return;
+ }
+ Index blockHighBits = blockHighBitsIt->second;
+ Index tmp = getTemp();
+ SetLocal* setLow = builder->makeSetLocal(tmp, curr->value);
+ SetLocal* setHigh = builder->makeSetLocal(
+ blockHighBits,
+ builder->makeGetLocal(valHighBits, i32)
+ );
+ curr->value = builder->makeGetLocal(tmp, i32);
+ curr->type = i32;
+ replaceCurrent(builder->blockify(setLow, setHigh, curr));
+ freeTemp(tmp);
+ }
+
+ void visitSwitch(Switch* curr) {
+ if (!hasOutParam(curr->value)) return;
+ Index outParam = fetchOutParam(curr->value);
+ Index tmp = getTemp();
+ bool didReuseOutParam = false;
+ Expression* result = curr;
+ std::vector<Name> targets;
+ auto processTarget = [&](Name target) -> Name {
+ auto labelIt = labelIndices.find(target);
+ if (labelIt == labelIndices.end() || labelIt->second == outParam) {
+ labelIndices[target] = outParam;
+ didReuseOutParam = true;
+ return target;
+ }
+ Index labelOutParam = labelIt->second;
+ Name newLabel("$i64toi32_" + std::string(target.c_str()));
+ result = builder->blockify(
+ builder->makeSetLocal(tmp, builder->makeBlock(newLabel, result)),
+ builder->makeSetLocal(
+ labelOutParam,
+ builder->makeGetLocal(outParam, i32)
+ ),
+ builder->makeGetLocal(tmp, i32)
+ );
+ assert(result->type == i32);
+ return newLabel;
+ };
+ for (auto target : curr->targets) {
+ targets.push_back(processTarget(target));
+ }
+ curr->default_ = processTarget(curr->default_);
+ replaceCurrent(result);
+ freeTemp(tmp);
+ if (!didReuseOutParam) {
+ freeTemp(outParam);
+ }
+ }
+
+ template <typename T>
+ using BuilderFunc = std::function<T*(std::vector<Expression*>&, WasmType)>;
+
+ template <typename T>
+ void visitGenericCall(T* curr, BuilderFunc<T> callBuilder) {
+ std::vector<Expression*> args;
+ for (auto* e : curr->operands) {
+ args.push_back(e);
+ if (hasOutParam(e)) {
+ Index argHighBits = fetchOutParam(e);
+ args.push_back(builder->makeGetLocal(argHighBits, i32));
+ freeTemp(argHighBits);
+ }
+ }
+ if (curr->type != i64) {
+ replaceCurrent(callBuilder(args, curr->type));
+ return;
+ }
+ Index lowBits = getTemp();
+ Index highBits = getTemp();
+ SetLocal* doCall = builder->makeSetLocal(
+ lowBits,
+ callBuilder(args, i32)
+ );
+ SetLocal* setHigh = builder->makeSetLocal(
+ highBits,
+ builder->makeGetGlobal(highBitsGlobal, i32)
+ );
+ GetLocal* getLow = builder->makeGetLocal(lowBits, i32);
+ Block* result = builder->blockify(doCall, setHigh, getLow);
+ freeTemp(lowBits);
+ setOutParam(result, highBits);
+ replaceCurrent(result);
+ }
+ void visitCall(Call* curr) {
+ visitGenericCall<Call>(
+ curr,
+ [&](std::vector<Expression*>& args, WasmType ty) {
+ return builder->makeCall(curr->target, args, ty);
+ }
+ );
+ }
+
+ void visitCallImport(CallImport* curr) {
+ // imports cannot contain i64s
+ return;
+ }
+
+ void visitCallIndirect(CallIndirect* curr) {
+ visitGenericCall<CallIndirect>(
+ curr,
+ [&](std::vector<Expression*>& args, WasmType ty) {
+ return builder->makeCallIndirect(
+ curr->fullType,
+ curr->target,
+ args,
+ ty
+ );
+ }
+ );
+ }
+
+ void visitGetLocal(GetLocal* curr) {
+ if (curr->type != i64) return;
+ curr->index = indexMap[curr->index];
+ curr->type = i32;
+ Index highBits = getTemp();
+ SetLocal *setHighBits = builder->makeSetLocal(
+ highBits,
+ builder->makeGetLocal(
+ curr->index + 1,
+ i32
+ )
+ );
+ Block* result = builder->blockify(setHighBits, curr);
+ replaceCurrent(result);
+ setOutParam(result, highBits);
+ }
+
+ void lowerTee(SetLocal* curr) {
+ Index highBits = fetchOutParam(curr->value);
+ Index tmp = getTemp();
+ curr->index = indexMap[curr->index];
+ curr->type = i32;
+ SetLocal* setLow = builder->makeSetLocal(tmp, curr);
+ SetLocal* setHigh = builder->makeSetLocal(
+ curr->index + 1,
+ builder->makeGetLocal(highBits, i32)
+ );
+ GetLocal* getLow = builder->makeGetLocal(tmp, i32);
+ Block* result = builder->blockify(setLow, setHigh, getLow);
+ replaceCurrent(result);
+ setOutParam(result, highBits);
+ freeTemp(tmp);
+ }
+
+ void visitSetLocal(SetLocal* curr) {
+ if (!hasOutParam(curr->value)) return;
+ if (curr->isTee()) {
+ lowerTee(curr);
+ return;
+ }
+ Index highBits = fetchOutParam(curr->value);
+ curr->index = indexMap[curr->index];
+ SetLocal* setHigh = builder->makeSetLocal(
+ curr->index + 1,
+ builder->makeGetLocal(highBits, i32)
+ );
+ Block* result = builder->blockify(curr, setHigh);
+ replaceCurrent(result);
+ freeTemp(highBits);
+ }
+
+ void visitGetGlobal(GetGlobal* curr) {
+ assert(false && "GetGlobal not implemented");
+ }
+
+ void visitSetGlobal(SetGlobal* curr) {
+ assert(false && "SetGlobal not implemented");
+ }
+
+ void visitLoad(Load* curr) {
+ if (curr->type != i64) return;
+ assert(!curr->isAtomic && "atomic load not implemented");
+ Index highBits = getTemp();
+ Index ptrTemp = getTemp();
+ SetLocal* setPtr = builder->makeSetLocal(ptrTemp, curr->ptr);
+ SetLocal* loadHigh;
+ if (curr->bytes == 8) {
+ loadHigh = builder->makeSetLocal(
+ highBits,
+ builder->makeLoad(
+ 4,
+ curr->signed_,
+ curr->offset + 4,
+ 1,
+ builder->makeGetLocal(ptrTemp, i32),
+ i32
+ )
+ );
+ } else {
+ loadHigh = builder->makeSetLocal(
+ highBits,
+ builder->makeConst(Literal(int32_t(0)))
+ );
+ }
+ curr->type = i32;
+ curr->bytes = std::min(curr->bytes, uint8_t(4));
+ curr->align = std::min(uint32_t(curr->align), uint32_t(4));
+ curr->ptr = builder->makeGetLocal(ptrTemp, i32);
+ Block* result = builder->blockify(setPtr, loadHigh, curr);
+ replaceCurrent(result);
+ setOutParam(result, highBits);
+ freeTemp(ptrTemp);
+ }
+
+ void visitStore(Store* curr) {
+ if (!hasOutParam(curr->value)) return;
+ assert(curr->offset + 4 > curr->offset);
+ assert(!curr->isAtomic && "atomic store not implemented");
+ Index highBits = fetchOutParam(curr->value);
+ uint8_t bytes = curr->bytes;
+ curr->bytes = std::min(curr->bytes, uint8_t(4));
+ curr->align = std::min(uint32_t(curr->align), uint32_t(4));
+ curr->valueType = i32;
+ if (bytes == 8) {
+ Index ptrTemp = getTemp();
+ SetLocal* setPtr = builder->makeSetLocal(ptrTemp, curr->ptr);
+ curr->ptr = builder->makeGetLocal(ptrTemp, i32);
+ Store* storeHigh = builder->makeStore(
+ 4,
+ curr->offset + 4,
+ 1,
+ builder->makeGetLocal(ptrTemp, i32),
+ builder->makeGetLocal(highBits, i32),
+ i32
+ );
+ replaceCurrent(builder->blockify(setPtr, curr, storeHigh));
+ freeTemp(ptrTemp);
+ }
+ freeTemp(highBits);
+ }
+
+ void visitAtomicRMW(AtomicRMW* curr) {
+ assert(false && "AtomicRMW not implemented");
+ }
+
+ void visitAtomicCmpxchg(AtomicCmpxchg* curr) {
+ assert(false && "AtomicCmpxchg not implemented");
+ }
+
+ void visitConst(Const* curr) {
+ if (curr->type != i64) return;
+ Index highBits = getTemp();
+ Const* lowVal = builder->makeConst(
+ Literal(int32_t(curr->value.geti64() & 0xffffffff))
+ );
+ SetLocal* setHigh = builder->makeSetLocal(
+ highBits,
+ builder->makeConst(
+ Literal(int32_t(uint64_t(curr->value.geti64()) >> 32))
+ )
+ );
+ Block* result = builder->blockify(setHigh, lowVal);
+ setOutParam(result, highBits);
+ replaceCurrent(result);
+ }
+
+ void lowerEqZInt64(Unary* curr) {
+ Index highBits = fetchOutParam(curr->value);
+ replaceCurrent(
+ builder->makeBinary(
+ AndInt32,
+ builder->makeUnary(EqZInt32, builder->makeGetLocal(highBits, i32)),
+ builder->makeUnary(EqZInt32, curr->value)
+ )
+ );
+ freeTemp(highBits);
+ }
+
+ void lowerExtendUInt32(Unary* curr) {
+ Index highBits = getTemp();
+ Block* result = builder->blockify(
+ builder->makeSetLocal(highBits, builder->makeConst(Literal(int32_t(0)))),
+ curr->value
+ );
+ setOutParam(result, highBits);
+ replaceCurrent(result);
+ }
+
+ void lowerWrapInt64(Unary* curr) {
+ freeTemp(fetchOutParam(curr->value));
+ replaceCurrent(curr->value);
+ }
+
+ bool unaryNeedsLowering(UnaryOp op) {
+ switch (op) {
+ case ClzInt64:
+ case CtzInt64:
+ case PopcntInt64:
+ case EqZInt64:
+ case ExtendSInt32:
+ case ExtendUInt32:
+ case WrapInt64:
+ case TruncSFloat32ToInt64:
+ case TruncUFloat32ToInt64:
+ case TruncSFloat64ToInt64:
+ case TruncUFloat64ToInt64:
+ case ReinterpretFloat64:
+ case ConvertSInt64ToFloat32:
+ case ConvertSInt64ToFloat64:
+ case ConvertUInt64ToFloat32:
+ case ConvertUInt64ToFloat64:
+ case ReinterpretInt64: return true;
+ default: return false;
+ }
+ }
+
+ void visitUnary(Unary* curr) {
+ if (!unaryNeedsLowering(curr->op)) return;
+ if (curr->type == unreachable || curr->value->type == unreachable) {
+ assert(!hasOutParam(curr->value));
+ replaceCurrent(curr->value);
+ return;
+ }
+ assert(hasOutParam(curr->value) || curr->type == i64);
+ switch (curr->op) {
+ case ClzInt64:
+ case CtzInt64:
+ case PopcntInt64: goto err;
+ case EqZInt64: lowerEqZInt64(curr); break;
+ case ExtendSInt32: goto err;
+ case ExtendUInt32: lowerExtendUInt32(curr); break;
+ case WrapInt64: lowerWrapInt64(curr); break;
+ case TruncSFloat32ToInt64:
+ case TruncUFloat32ToInt64:
+ case TruncSFloat64ToInt64:
+ case TruncUFloat64ToInt64:
+ case ReinterpretFloat64:
+ case ConvertSInt64ToFloat32:
+ case ConvertSInt64ToFloat64:
+ case ConvertUInt64ToFloat32:
+ case ConvertUInt64ToFloat64:
+ case ReinterpretInt64:
+ err: default:
+ std::cerr << "Unhandled unary operator: " << curr->op << std::endl;
+ abort();
+ }
+ }
+
+ Block* lowerAdd(Block* result, Index leftLow, Index leftHigh,
+ Index rightLow, Index rightHigh) {
+ SetLocal* addLow = builder->makeSetLocal(
+ leftHigh,
+ builder->makeBinary(
+ AddInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(rightLow, i32)
+ )
+ );
+ SetLocal* addHigh = builder->makeSetLocal(
+ rightHigh,
+ builder->makeBinary(
+ AddInt32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(rightHigh, i32)
+ )
+ );
+ SetLocal* carryBit = builder->makeSetLocal(
+ rightHigh,
+ builder->makeBinary(
+ AddInt32,
+ builder->makeGetLocal(rightHigh, i32),
+ builder->makeConst(Literal(int32_t(1)))
+ )
+ );
+ If* checkOverflow = builder->makeIf(
+ builder->makeBinary(
+ LtUInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(rightLow, i32)
+ ),
+ carryBit
+ );
+ GetLocal* getLow = builder->makeGetLocal(leftHigh, i32);
+ result = builder->blockify(result, addLow, addHigh, checkOverflow, getLow);
+ freeTemp(leftLow);
+ freeTemp(leftHigh);
+ freeTemp(rightLow);
+ setOutParam(result, rightHigh);
+ return result;
+ }
+
+ Block* lowerMul(Block* result, Index leftLow, Index leftHigh, Index rightLow,
+ Index rightHigh) {
+ // high bits = ll*rh + lh*rl + ll1*rl1 + (ll0*rl1)>>16 + (ll1*rl0)>>16
+ // low bits = ll*rl
+ Index leftLow0 = getTemp();
+ Index leftLow1 = getTemp();
+ Index rightLow0 = getTemp();
+ Index rightLow1 = getTemp();
+ SetLocal* setLL0 = builder->makeSetLocal(
+ leftLow0,
+ builder->makeBinary(
+ AndInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeConst(Literal(int32_t(0xffff)))
+ )
+ );
+ SetLocal* setLL1 = builder->makeSetLocal(
+ leftLow1,
+ builder->makeBinary(
+ ShrUInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeConst(Literal(int32_t(16)))
+ )
+ );
+ SetLocal* setRL0 = builder->makeSetLocal(
+ rightLow0,
+ builder->makeBinary(
+ AndInt32,
+ builder->makeGetLocal(rightLow, i32),
+ builder->makeConst(Literal(int32_t(0xffff)))
+ )
+ );
+ SetLocal* setRL1 = builder->makeSetLocal(
+ rightLow1,
+ builder->makeBinary(
+ ShrUInt32,
+ builder->makeGetLocal(rightLow, i32),
+ builder->makeConst(Literal(int32_t(16)))
+ )
+ );
+ SetLocal* setLLRH = builder->makeSetLocal(
+ rightHigh,
+ builder->makeBinary(
+ MulInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(rightHigh, i32)
+ )
+ );
+ auto addToHighBits = [&](Expression* expr) -> SetLocal* {
+ return builder->makeSetLocal(
+ rightHigh,
+ builder->makeBinary(
+ AddInt32,
+ builder->makeGetLocal(rightHigh, i32),
+ expr
+ )
+ );
+ };
+ SetLocal* addLHRL = addToHighBits(
+ builder->makeBinary(
+ MulInt32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(rightLow, i32)
+ )
+ );
+ SetLocal* addLL1RL1 = addToHighBits(
+ builder->makeBinary(
+ MulInt32,
+ builder->makeGetLocal(leftLow1, i32),
+ builder->makeGetLocal(rightLow1, i32)
+ )
+ );
+ SetLocal* addLL0RL1 = addToHighBits(
+ builder->makeBinary(
+ ShrUInt32,
+ builder->makeBinary(
+ MulInt32,
+ builder->makeGetLocal(leftLow0, i32),
+ builder->makeGetLocal(rightLow1, i32)
+ ),
+ builder->makeConst(Literal(int32_t(16)))
+ )
+ );
+ SetLocal* addLL1RL0 = addToHighBits(
+ builder->makeBinary(
+ ShrUInt32,
+ builder->makeBinary(
+ MulInt32,
+ builder->makeGetLocal(leftLow1, i32),
+ builder->makeGetLocal(rightLow0, i32)
+ ),
+ builder->makeConst(Literal(int32_t(16)))
+ )
+ );
+ Binary* getLow = builder->makeBinary(
+ MulInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(rightLow, i32)
+ );
+ result = builder->blockify(
+ result,
+ setLL0,
+ setLL1,
+ setRL0,
+ setRL1,
+ setLLRH,
+ addLHRL,
+ addLL1RL1,
+ addLL0RL1,
+ addLL1RL0,
+ getLow
+ );
+ freeTemp(leftLow0);
+ freeTemp(leftLow1);
+ freeTemp(rightLow0);
+ freeTemp(rightLow1);
+ freeTemp(leftLow);
+ freeTemp(leftHigh);
+ freeTemp(rightLow);
+ setOutParam(result, rightHigh);
+ return result;
+ }
+
+ Block* lowerBitwise(BinaryOp op, Block* result, Index leftLow, Index leftHigh,
+ Index rightLow, Index rightHigh) {
+ BinaryOp op32;
+ switch (op) {
+ case AndInt64: op32 = AndInt32; break;
+ case OrInt64: op32 = OrInt32; break;
+ case XorInt64: op32 = XorInt32; break;
+ default: abort();
+ }
+ result = builder->blockify(
+ result,
+ builder->makeSetLocal(
+ rightHigh,
+ builder->makeBinary(
+ op32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(rightHigh, i32)
+ )
+ ),
+ builder->makeBinary(
+ op32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(rightLow, i32)
+ )
+ );
+ freeTemp(leftLow);
+ freeTemp(leftHigh);
+ freeTemp(rightLow);
+ setOutParam(result, rightHigh);
+ return result;
+ }
+
+ Block* makeLargeShl(Index highBits, Index leftLow, Index shift) {
+ return builder->blockify(
+ builder->makeSetLocal(
+ highBits,
+ builder->makeBinary(
+ ShlInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(shift, i32)
+ )
+ ),
+ builder->makeConst(Literal(int32_t(0)))
+ );
+ }
+
+ Block* makeLargeShrU(Index highBits, Index leftHigh, Index shift) {
+ return builder->blockify(
+ builder->makeSetLocal(highBits, builder->makeConst(Literal(int32_t(0)))),
+ builder->makeBinary(
+ ShrUInt32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(shift, i32)
+ )
+ );
+ }
+
+ Block* makeSmallShl(Index highBits, Index leftLow, Index leftHigh,
+ Index shift, Binary* shiftMask, Binary* widthLessShift) {
+ Binary* shiftedInBits = builder->makeBinary(
+ AndInt32,
+ shiftMask,
+ builder->makeBinary(
+ ShrUInt32,
+ builder->makeGetLocal(leftLow, i32),
+ widthLessShift
+ )
+ );
+ Binary* shiftHigh = builder->makeBinary(
+ ShlInt32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(shift, i32)
+ );
+ return builder->blockify(
+ builder->makeSetLocal(
+ highBits,
+ builder->makeBinary(OrInt32, shiftedInBits, shiftHigh)
+ ),
+ builder->makeBinary(
+ ShlInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(shift, i32)
+ )
+ );
+ }
+
+ Block* makeSmallShrU(Index highBits, Index leftLow, Index leftHigh,
+ Index shift, Binary* shiftMask, Binary* widthLessShift) {
+ Binary* shiftedInBits = builder->makeBinary(
+ ShlInt32,
+ builder->makeBinary(
+ AndInt32,
+ shiftMask,
+ builder->makeGetLocal(leftHigh, i32)
+ ),
+ widthLessShift
+ );
+ Binary* shiftLow = builder->makeBinary(
+ ShrUInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(shift, i32)
+ );
+ return builder->blockify(
+ builder->makeSetLocal(
+ highBits,
+ builder->makeBinary(
+ ShrUInt32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(shift, i32)
+ )
+ ),
+ builder->makeBinary(OrInt32, shiftedInBits, shiftLow)
+ );
+ }
+
+ Block* lowerShU(BinaryOp op, Block* result, Index leftLow, Index leftHigh,
+ Index rightLow, Index rightHigh) {
+ assert(op == ShlInt64 || op == ShrUInt64);
+ // shift left lowered as:
+ // if 32 <= rightLow % 64:
+ // high = leftLow << k; low = 0
+ // else:
+ // high = (((1 << k) - 1) & (leftLow >> (32 - k))) | (leftHigh << k);
+ // low = leftLow << k
+ // where k = shift % 32. shift right is similar.
+ Index shift = getTemp();
+ SetLocal* setShift = builder->makeSetLocal(
+ shift,
+ builder->makeBinary(
+ AndInt32,
+ builder->makeGetLocal(rightLow, i32),
+ builder->makeConst(Literal(int32_t(32 - 1)))
+ )
+ );
+ Binary* isLargeShift = builder->makeBinary(
+ LeUInt32,
+ builder->makeConst(Literal(int32_t(32))),
+ builder->makeBinary(
+ AndInt32,
+ builder->makeGetLocal(rightLow, i32),
+ builder->makeConst(Literal(int32_t(64 - 1)))
+ )
+ );
+ Block* largeShiftBlock;
+ switch (op) {
+ case ShlInt64:
+ largeShiftBlock = makeLargeShl(rightHigh, leftLow, shift); break;
+ case ShrUInt64:
+ largeShiftBlock = makeLargeShrU(rightHigh, leftHigh, shift); break;
+ default: abort();
+ }
+ Binary* shiftMask = builder->makeBinary(
+ SubInt32,
+ builder->makeBinary(
+ ShlInt32,
+ builder->makeConst(Literal(int32_t(1))),
+ builder->makeGetLocal(shift, i32)
+ ),
+ builder->makeConst(Literal(int32_t(1)))
+ );
+ Binary* widthLessShift = builder->makeBinary(
+ SubInt32,
+ builder->makeConst(Literal(int32_t(32))),
+ builder->makeGetLocal(shift, i32)
+ );
+ Block* smallShiftBlock;
+ switch(op) {
+ case ShlInt64: {
+ smallShiftBlock = makeSmallShl(rightHigh, leftLow, leftHigh,
+ shift, shiftMask, widthLessShift);
+ break;
+ }
+ case ShrUInt64: {
+ smallShiftBlock = makeSmallShrU(rightHigh, leftLow, leftHigh,
+ shift, shiftMask, widthLessShift);
+ break;
+ }
+ default: abort();
+ }
+ If* ifLargeShift = builder->makeIf(
+ isLargeShift,
+ largeShiftBlock,
+ smallShiftBlock
+ );
+ result = builder->blockify(result, setShift, ifLargeShift);
+ freeTemp(shift);
+ freeTemp(leftLow);
+ freeTemp(leftHigh);
+ freeTemp(rightLow);
+ setOutParam(result, rightHigh);
+ return result;
+ }
+
+ Block* lowerEq(Block* result, Index leftLow, Index leftHigh,
+ Index rightLow, Index rightHigh) {
+ freeTemp(leftLow);
+ freeTemp(leftHigh);
+ freeTemp(rightLow);
+ freeTemp(rightHigh);
+ return builder->blockify(
+ result,
+ builder->makeBinary(
+ AndInt32,
+ builder->makeBinary(
+ EqInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(rightLow, i32)
+ ),
+ builder->makeBinary(
+ EqInt32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(rightHigh, i32)
+ )
+ )
+ );
+ }
+
+ Block* lowerNe(Block* result, Index leftLow, Index leftHigh,
+ Index rightLow, Index rightHigh) {
+ freeTemp(leftLow);
+ freeTemp(leftHigh);
+ freeTemp(rightLow);
+ freeTemp(rightHigh);
+ return builder->blockify(
+ result,
+ builder->makeBinary(
+ OrInt32,
+ builder->makeBinary(
+ NeInt32,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(rightLow, i32)
+ ),
+ builder->makeBinary(
+ NeInt32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(rightHigh, i32)
+ )
+ )
+ );
+ }
+
+ Block* lowerUComp(BinaryOp op, Block* result, Index leftLow, Index leftHigh,
+ Index rightLow, Index rightHigh) {
+ freeTemp(leftLow);
+ freeTemp(leftHigh);
+ freeTemp(rightLow);
+ freeTemp(rightHigh);
+ BinaryOp highOp, lowOp;
+ switch (op) {
+ case LtUInt64: highOp = LtUInt32; lowOp = LtUInt32; break;
+ case LeUInt64: highOp = LtUInt32; lowOp = LeUInt32; break;
+ case GtUInt64: highOp = GtUInt32; lowOp = GtUInt32; break;
+ case GeUInt64: highOp = GtUInt32; lowOp = GeUInt32; break;
+ default: abort();
+ }
+ Binary* compHigh = builder->makeBinary(
+ highOp,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(rightHigh, i32)
+ );
+ Binary* eqHigh = builder->makeBinary(
+ EqInt32,
+ builder->makeGetLocal(leftHigh, i32),
+ builder->makeGetLocal(rightHigh, i32)
+ );
+ Binary* compLow = builder->makeBinary(
+ lowOp,
+ builder->makeGetLocal(leftLow, i32),
+ builder->makeGetLocal(rightLow, i32)
+ );
+ return builder->blockify(
+ result,
+ builder->makeBinary(
+ OrInt32,
+ compHigh,
+ builder->makeBinary(AndInt32, eqHigh, compLow)
+ )
+ );
+ }
+
+ bool binaryNeedsLowering(BinaryOp op) {
+ switch (op) {
+ case AddInt64:
+ case SubInt64:
+ case MulInt64:
+ case DivSInt64:
+ case DivUInt64:
+ case RemSInt64:
+ case RemUInt64:
+ case AndInt64:
+ case OrInt64:
+ case XorInt64:
+ case ShlInt64:
+ case ShrUInt64:
+ case ShrSInt64:
+ case RotLInt64:
+ case RotRInt64:
+ case EqInt64:
+ case NeInt64:
+ case LtSInt64:
+ case LtUInt64:
+ case LeSInt64:
+ case LeUInt64:
+ case GtSInt64:
+ case GtUInt64:
+ case GeSInt64:
+ case GeUInt64: return true;
+ default: return false;
+ }
+ }
+
+ void visitBinary(Binary* curr) {
+ if (!binaryNeedsLowering(curr->op)) return;
+ if (!hasOutParam(curr->left)) {
+ // left unreachable, replace self with left
+ replaceCurrent(curr->left);
+ if (hasOutParam(curr->right)) {
+ freeTemp(fetchOutParam(curr->right));
+ }
+ return;
+ }
+ if (!hasOutParam(curr->right)) {
+ // right unreachable, replace self with left then right
+ replaceCurrent(
+ builder->blockify(builder->makeDrop(curr->left), curr->right)
+ );
+ freeTemp(fetchOutParam(curr->left));
+ return;
+ }
+ // left and right reachable, lower normally
+ Index leftLow = getTemp();
+ Index leftHigh = fetchOutParam(curr->left);
+ Index rightLow = getTemp();
+ Index rightHigh = fetchOutParam(curr->right);
+ SetLocal* setRight = builder->makeSetLocal(rightLow, curr->right);
+ SetLocal* setLeft = builder->makeSetLocal(leftLow, curr->left);
+ Block* result = builder->blockify(setLeft, setRight);
+ switch (curr->op) {
+ case AddInt64: {
+ replaceCurrent(lowerAdd(result, leftLow, leftHigh, rightLow,
+ rightHigh));
+ break;
+ }
+ case SubInt64: goto err;
+ case MulInt64: {
+ replaceCurrent(lowerMul(result, leftLow, leftHigh, rightLow,
+ rightHigh));
+ break;
+ }
+ case DivSInt64:
+ case DivUInt64:
+ case RemSInt64:
+ case RemUInt64: goto err;
+ case AndInt64:
+ case OrInt64:
+ case XorInt64: {
+ replaceCurrent(lowerBitwise(curr->op, result, leftLow, leftHigh,
+ rightLow, rightHigh));
+ break;
+ }
+ case ShlInt64:
+ case ShrUInt64: {
+ replaceCurrent(lowerShU(curr->op, result, leftLow, leftHigh, rightLow,
+ rightHigh));
+ break;
+ }
+ case ShrSInt64:
+ case RotLInt64:
+ case RotRInt64: goto err;
+ case EqInt64: {
+ replaceCurrent(lowerEq(result, leftLow, leftHigh, rightLow, rightHigh));
+ break;
+ }
+ case NeInt64: {
+ replaceCurrent(lowerNe(result, leftLow, leftHigh, rightLow, rightHigh));
+ break;
+ }
+ case LtSInt64:
+ case LeSInt64:
+ case GtSInt64:
+ case GeSInt64: goto err;
+ case LtUInt64:
+ case LeUInt64:
+ case GtUInt64:
+ case GeUInt64: {
+ replaceCurrent(lowerUComp(curr->op, result, leftLow, leftHigh, rightLow,
+ rightHigh));
+ break;
+ }
+ err: default: {
+ std::cerr << "Unhandled binary op " << curr->op << std::endl;
+ abort();
+ }
+ }
+ }
+
+ void visitSelect(Select* curr) {
+ visitBranching<Select>(curr);
+ }
+
+ void visitDrop(Drop* curr) {
+ if (!hasOutParam(curr->value)) return;
+ freeTemp(fetchOutParam(curr->value));
+ }
+
+ void visitReturn(Return* curr) {
+ if (!hasOutParam(curr->value)) return;
+ Index lowBits = getTemp();
+ Index highBits = fetchOutParam(curr->value);
+ SetLocal* setLow = builder->makeSetLocal(lowBits, curr->value);
+ SetGlobal* setHigh = builder->makeSetGlobal(
+ highBitsGlobal,
+ builder->makeGetLocal(highBits, i32)
+ );
+ curr->value = builder->makeGetLocal(lowBits, i32);
+ Block* result = builder->blockify(setLow, setHigh, curr);
+ replaceCurrent(result);
+ freeTemp(lowBits);
+ freeTemp(highBits);
+ }
+
+private:
+ std::unique_ptr<Builder> builder;
+ std::unordered_map<Index, Index> indexMap;
+ std::unordered_map<Expression*, Index> returnIndices;
+ std::unordered_map<Name, Index> labelIndices;
+ std::vector<Index> freeTemps;
+ Index nextTemp;
+
+ // TODO: RAII for temp var allocation
+ Index getTemp() {
+ Index ret;
+ if (freeTemps.size() > 0) {
+ ret = freeTemps.back();
+ freeTemps.pop_back();
+ } else {
+ ret = nextTemp++;
+ }
+ return ret;
+ }
+
+ void freeTemp(Index t) {
+ assert(std::find(freeTemps.begin(), freeTemps.end(), t) == freeTemps.end());
+ freeTemps.push_back(t);
+ }
+
+ bool hasOutParam(Expression* e) {
+ return returnIndices.find(e) != returnIndices.end();
+ }
+
+ void setOutParam(Expression* e, Index idx) {
+ returnIndices[e] = idx;
+ }
+
+ Index fetchOutParam(Expression* e) {
+ assert(returnIndices.find(e) != returnIndices.end());
+ Index ret = returnIndices[e];
+ returnIndices.erase(e);
+ return ret;
+ }
+};
+
+Name I64ToI32Lowering::highBitsGlobal("i64toi32_i32$HIGH_BITS");
+
+Pass *createI64ToI32LoweringPass() {
+ return new I64ToI32Lowering();
+}
+
+}