diff options
author | Thomas Lively <7121787+tlively@users.noreply.github.com> | 2017-09-01 14:26:01 -0400 |
---|---|---|
committer | Alon Zakai <alonzakai@gmail.com> | 2017-09-01 11:26:01 -0700 |
commit | b013f744e3d70effd9be348cbde7fb93f0a16c6a (patch) | |
tree | 3b122293005d3370c931175eed92ad61e9dfd851 /src | |
parent | b1e8b1b515b2a1d0264975abc4de39c8044f7195 (diff) | |
download | binaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.tar.gz binaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.tar.bz2 binaryen-b013f744e3d70effd9be348cbde7fb93f0a16c6a.zip |
i64 to i32 lowering for wasm2asm (#1134)
Diffstat (limited to 'src')
-rw-r--r-- | src/emscripten-optimizer/simple_ast.h | 8 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/I64ToI32Lowering.cpp | 1197 | ||||
-rw-r--r-- | src/passes/pass.cpp | 1 | ||||
-rw-r--r-- | src/passes/passes.h | 1 | ||||
-rw-r--r-- | src/wasm-builder.h | 9 | ||||
-rw-r--r-- | src/wasm-traversal.h | 1 | ||||
-rw-r--r-- | src/wasm.h | 1 | ||||
-rw-r--r-- | src/wasm2asm.h | 334 |
9 files changed, 1435 insertions, 118 deletions
diff --git a/src/emscripten-optimizer/simple_ast.h b/src/emscripten-optimizer/simple_ast.h index 870765323..62bf975f0 100644 --- a/src/emscripten-optimizer/simple_ast.h +++ b/src/emscripten-optimizer/simple_ast.h @@ -1006,7 +1006,7 @@ struct JSPrinter { if (childPrecedence < parentPrecedence) return false; // definitely cool // equal precedence, so associativity (rtl/ltr) is what matters // (except for some exceptions, where multiple operators can combine into confusion) - if (parent[0] == UNARY_PREFIX) { + if (parent->isArray() && parent[0] == UNARY_PREFIX) { assert(child[0] == UNARY_PREFIX); if ((parent[1] == PLUS || parent[1] == MINUS) && child[1] == parent[1]) { // cannot emit ++x when we mean +(+x) @@ -1036,8 +1036,10 @@ struct JSPrinter { } void printUnaryPrefix(Ref node) { - if (finalize && node[1] == PLUS && (node[2]->isNumber() || - (node[2][0] == UNARY_PREFIX && node[2][1] == MINUS && node[2][2]->isNumber()))) { + if (finalize && node[1] == PLUS && + (node[2]->isNumber() || + (node[2]->isArray() && node[2][0] == UNARY_PREFIX && + node[2][1] == MINUS && node[2][2]->isNumber()))) { // emit a finalized number int last = used; print(node[2]); diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index a575d8b27..edbf945cf 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -11,6 +11,7 @@ SET(passes_SOURCES LegalizeJSInterface.cpp LocalCSE.cpp LogExecution.cpp + I64ToI32Lowering.cpp InstrumentLocals.cpp InstrumentMemory.cpp MemoryPacking.cpp diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp new file mode 100644 index 000000000..65871ca6f --- /dev/null +++ b/src/passes/I64ToI32Lowering.cpp @@ -0,0 +1,1197 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Lowers i64s to i32s by splitting variables and arguments +// into pairs of i32s. i64 return values are lowered by +// returning the low half and storing the high half into a +// global. +// + +#include <algorithm> +#include "wasm.h" +#include "pass.h" +#include "emscripten-optimizer/istring.h" +#include "support/name.h" +#include "wasm-builder.h" + + +namespace wasm { + +static Name makeHighName(Name n) { + return Name( + cashew::IString((std::string(n.c_str()) + "$hi").c_str(), false) + ); +} + +struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { + static Name highBitsGlobal; + + // false since function types need to be lowered + // TODO: allow module-level transformations in parallel passes + bool isFunctionParallel() override { return false; } + + Pass* create() override { + return new I64ToI32Lowering; + } + + void doWalkModule(Module* module) { + if (!builder) builder = make_unique<Builder>(*module); + // add new globals for high bits + for (size_t i = 0, globals = module->globals.size(); i < globals; ++i) { + auto& curr = module->globals[i]; + if (curr->type != i64) continue; + curr->type = i32; + auto* high = new Global(*curr); + high->name = makeHighName(curr->name); + module->addGlobal(high); + } + PostWalker<I64ToI32Lowering>::doWalkModule(module); + } + + void visitFunctionType(FunctionType* curr) { + std::vector<WasmType> params; + for (auto t : curr->params) { + if (t == i64) { + params.push_back(i32); + params.push_back(i32); + } else { + params.push_back(t); + } + } + std::swap(params, curr->params); + if (curr->result == i64) { + curr->result = i32; + } + } + + void doWalkFunction(Function* func) { + // create builder here if this is first entry to module for this object + if (!builder) builder = make_unique<Builder>(*getModule()); + indexMap.clear(); + returnIndices.clear(); + labelIndices.clear(); + freeTemps.clear(); + Function oldFunc(*func); + func->params.clear(); + func->vars.clear(); + func->localNames.clear(); + func->localIndices.clear(); + Index newIdx = 0; + for (Index i = 0; i < oldFunc.getNumLocals(); ++i) { + assert(oldFunc.hasLocalName(i)); + Name lowName = oldFunc.getLocalName(i); + Name highName = makeHighName(lowName); + WasmType paramType = oldFunc.getLocalType(i); + auto builderFunc = (i < oldFunc.getVarIndexBase()) ? + Builder::addParam : + static_cast<Index (*)(Function*, Name, WasmType)>(Builder::addVar); + if (paramType == i64) { + builderFunc(func, lowName, i32); + builderFunc(func, highName, i32); + indexMap[i] = newIdx; + newIdx += 2; + } else { + builderFunc(func, lowName, paramType); + indexMap[i] = newIdx++; + } + } + nextTemp = func->getNumLocals(); + PostWalker<I64ToI32Lowering>::doWalkFunction(func); + } + + void visitFunction(Function* func) { + if (func->result == i64) { + func->result = i32; + // body may not have out param if it ends with control flow + if (hasOutParam(func->body)) { + Index highBits = fetchOutParam(func->body); + Index lowBits = getTemp(); + SetLocal* setLow = builder->makeSetLocal( + lowBits, + func->body + ); + SetGlobal* setHigh = builder->makeSetGlobal( + highBitsGlobal, + builder->makeGetLocal(highBits, i32) + ); + GetLocal* getLow = builder->makeGetLocal(lowBits, i32); + func->body = builder->blockify(setLow, setHigh, getLow); + freeTemp(highBits); + freeTemp(lowBits); + } + } + assert(freeTemps.size() == nextTemp - func->getNumLocals()); + int idx = 0; + for (size_t i = func->getNumLocals(); i < nextTemp; i++) { + Name tmpName("i64toi32_i32$" + std::to_string(idx++)); + builder->addVar(func, tmpName, i32); + } + } + + void visitBlock(Block* curr) { + if (curr->list.size() == 0) return; + if (curr->type == i64) curr->type = i32; + auto highBitsIt = labelIndices.find(curr->name); + if (!hasOutParam(curr->list.back())) { + if (highBitsIt != labelIndices.end()) { + setOutParam(curr, highBitsIt->second); + } + return; + } + Index lastHighBits = fetchOutParam(curr->list.back()); + if (highBitsIt == labelIndices.end() || + highBitsIt->second == lastHighBits) { + setOutParam(curr, lastHighBits); + if (highBitsIt != labelIndices.end()) { + labelIndices.erase(highBitsIt); + } + return; + } + Index highBits = highBitsIt->second; + Index tmp = getTemp(); + labelIndices.erase(highBitsIt); + SetLocal* setLow = builder->makeSetLocal(tmp, curr->list.back()); + SetLocal* setHigh = builder->makeSetLocal( + highBits, + builder->makeGetLocal(lastHighBits, i32) + ); + GetLocal* getLow = builder->makeGetLocal(tmp, i32); + curr->list.back() = builder->blockify(setLow, setHigh, getLow); + setOutParam(curr, highBits); + freeTemp(lastHighBits); + freeTemp(tmp); + } + + // If and Select have identical code + template <typename T> + void visitBranching(T* curr) { + if (!hasOutParam(curr->ifTrue)) return; + assert(curr->ifFalse != nullptr && "Nullable ifFalse found"); + Index highBits = fetchOutParam(curr->ifTrue); + Index falseBits = fetchOutParam(curr->ifFalse); + Index tmp = getTemp(); + curr->type = i32; + curr->ifFalse = builder->blockify( + builder->makeSetLocal(tmp, curr->ifFalse), + builder->makeSetLocal( + highBits, + builder->makeGetLocal(falseBits, i32) + ), + builder->makeGetLocal(tmp, i32) + ); + freeTemp(tmp); + freeTemp(falseBits); + setOutParam(curr, highBits); + } + + void visitIf(If* curr) { + visitBranching<If>(curr); + } + + void visitLoop(Loop* curr) { + assert(labelIndices.find(curr->name) == labelIndices.end()); + if (curr->type != i64) return; + curr->type = i32; + setOutParam(curr, fetchOutParam(curr->body)); + } + + void visitBreak(Break* curr) { + if (!hasOutParam(curr->value)) return; + assert(curr->value != nullptr); + Index valHighBits = fetchOutParam(curr->value); + auto blockHighBitsIt = labelIndices.find(curr->name); + if (blockHighBitsIt == labelIndices.end()) { + labelIndices[curr->name] = valHighBits; + curr->type = i32; + return; + } + Index blockHighBits = blockHighBitsIt->second; + Index tmp = getTemp(); + SetLocal* setLow = builder->makeSetLocal(tmp, curr->value); + SetLocal* setHigh = builder->makeSetLocal( + blockHighBits, + builder->makeGetLocal(valHighBits, i32) + ); + curr->value = builder->makeGetLocal(tmp, i32); + curr->type = i32; + replaceCurrent(builder->blockify(setLow, setHigh, curr)); + freeTemp(tmp); + } + + void visitSwitch(Switch* curr) { + if (!hasOutParam(curr->value)) return; + Index outParam = fetchOutParam(curr->value); + Index tmp = getTemp(); + bool didReuseOutParam = false; + Expression* result = curr; + std::vector<Name> targets; + auto processTarget = [&](Name target) -> Name { + auto labelIt = labelIndices.find(target); + if (labelIt == labelIndices.end() || labelIt->second == outParam) { + labelIndices[target] = outParam; + didReuseOutParam = true; + return target; + } + Index labelOutParam = labelIt->second; + Name newLabel("$i64toi32_" + std::string(target.c_str())); + result = builder->blockify( + builder->makeSetLocal(tmp, builder->makeBlock(newLabel, result)), + builder->makeSetLocal( + labelOutParam, + builder->makeGetLocal(outParam, i32) + ), + builder->makeGetLocal(tmp, i32) + ); + assert(result->type == i32); + return newLabel; + }; + for (auto target : curr->targets) { + targets.push_back(processTarget(target)); + } + curr->default_ = processTarget(curr->default_); + replaceCurrent(result); + freeTemp(tmp); + if (!didReuseOutParam) { + freeTemp(outParam); + } + } + + template <typename T> + using BuilderFunc = std::function<T*(std::vector<Expression*>&, WasmType)>; + + template <typename T> + void visitGenericCall(T* curr, BuilderFunc<T> callBuilder) { + std::vector<Expression*> args; + for (auto* e : curr->operands) { + args.push_back(e); + if (hasOutParam(e)) { + Index argHighBits = fetchOutParam(e); + args.push_back(builder->makeGetLocal(argHighBits, i32)); + freeTemp(argHighBits); + } + } + if (curr->type != i64) { + replaceCurrent(callBuilder(args, curr->type)); + return; + } + Index lowBits = getTemp(); + Index highBits = getTemp(); + SetLocal* doCall = builder->makeSetLocal( + lowBits, + callBuilder(args, i32) + ); + SetLocal* setHigh = builder->makeSetLocal( + highBits, + builder->makeGetGlobal(highBitsGlobal, i32) + ); + GetLocal* getLow = builder->makeGetLocal(lowBits, i32); + Block* result = builder->blockify(doCall, setHigh, getLow); + freeTemp(lowBits); + setOutParam(result, highBits); + replaceCurrent(result); + } + void visitCall(Call* curr) { + visitGenericCall<Call>( + curr, + [&](std::vector<Expression*>& args, WasmType ty) { + return builder->makeCall(curr->target, args, ty); + } + ); + } + + void visitCallImport(CallImport* curr) { + // imports cannot contain i64s + return; + } + + void visitCallIndirect(CallIndirect* curr) { + visitGenericCall<CallIndirect>( + curr, + [&](std::vector<Expression*>& args, WasmType ty) { + return builder->makeCallIndirect( + curr->fullType, + curr->target, + args, + ty + ); + } + ); + } + + void visitGetLocal(GetLocal* curr) { + if (curr->type != i64) return; + curr->index = indexMap[curr->index]; + curr->type = i32; + Index highBits = getTemp(); + SetLocal *setHighBits = builder->makeSetLocal( + highBits, + builder->makeGetLocal( + curr->index + 1, + i32 + ) + ); + Block* result = builder->blockify(setHighBits, curr); + replaceCurrent(result); + setOutParam(result, highBits); + } + + void lowerTee(SetLocal* curr) { + Index highBits = fetchOutParam(curr->value); + Index tmp = getTemp(); + curr->index = indexMap[curr->index]; + curr->type = i32; + SetLocal* setLow = builder->makeSetLocal(tmp, curr); + SetLocal* setHigh = builder->makeSetLocal( + curr->index + 1, + builder->makeGetLocal(highBits, i32) + ); + GetLocal* getLow = builder->makeGetLocal(tmp, i32); + Block* result = builder->blockify(setLow, setHigh, getLow); + replaceCurrent(result); + setOutParam(result, highBits); + freeTemp(tmp); + } + + void visitSetLocal(SetLocal* curr) { + if (!hasOutParam(curr->value)) return; + if (curr->isTee()) { + lowerTee(curr); + return; + } + Index highBits = fetchOutParam(curr->value); + curr->index = indexMap[curr->index]; + SetLocal* setHigh = builder->makeSetLocal( + curr->index + 1, + builder->makeGetLocal(highBits, i32) + ); + Block* result = builder->blockify(curr, setHigh); + replaceCurrent(result); + freeTemp(highBits); + } + + void visitGetGlobal(GetGlobal* curr) { + assert(false && "GetGlobal not implemented"); + } + + void visitSetGlobal(SetGlobal* curr) { + assert(false && "SetGlobal not implemented"); + } + + void visitLoad(Load* curr) { + if (curr->type != i64) return; + assert(!curr->isAtomic && "atomic load not implemented"); + Index highBits = getTemp(); + Index ptrTemp = getTemp(); + SetLocal* setPtr = builder->makeSetLocal(ptrTemp, curr->ptr); + SetLocal* loadHigh; + if (curr->bytes == 8) { + loadHigh = builder->makeSetLocal( + highBits, + builder->makeLoad( + 4, + curr->signed_, + curr->offset + 4, + 1, + builder->makeGetLocal(ptrTemp, i32), + i32 + ) + ); + } else { + loadHigh = builder->makeSetLocal( + highBits, + builder->makeConst(Literal(int32_t(0))) + ); + } + curr->type = i32; + curr->bytes = std::min(curr->bytes, uint8_t(4)); + curr->align = std::min(uint32_t(curr->align), uint32_t(4)); + curr->ptr = builder->makeGetLocal(ptrTemp, i32); + Block* result = builder->blockify(setPtr, loadHigh, curr); + replaceCurrent(result); + setOutParam(result, highBits); + freeTemp(ptrTemp); + } + + void visitStore(Store* curr) { + if (!hasOutParam(curr->value)) return; + assert(curr->offset + 4 > curr->offset); + assert(!curr->isAtomic && "atomic store not implemented"); + Index highBits = fetchOutParam(curr->value); + uint8_t bytes = curr->bytes; + curr->bytes = std::min(curr->bytes, uint8_t(4)); + curr->align = std::min(uint32_t(curr->align), uint32_t(4)); + curr->valueType = i32; + if (bytes == 8) { + Index ptrTemp = getTemp(); + SetLocal* setPtr = builder->makeSetLocal(ptrTemp, curr->ptr); + curr->ptr = builder->makeGetLocal(ptrTemp, i32); + Store* storeHigh = builder->makeStore( + 4, + curr->offset + 4, + 1, + builder->makeGetLocal(ptrTemp, i32), + builder->makeGetLocal(highBits, i32), + i32 + ); + replaceCurrent(builder->blockify(setPtr, curr, storeHigh)); + freeTemp(ptrTemp); + } + freeTemp(highBits); + } + + void visitAtomicRMW(AtomicRMW* curr) { + assert(false && "AtomicRMW not implemented"); + } + + void visitAtomicCmpxchg(AtomicCmpxchg* curr) { + assert(false && "AtomicCmpxchg not implemented"); + } + + void visitConst(Const* curr) { + if (curr->type != i64) return; + Index highBits = getTemp(); + Const* lowVal = builder->makeConst( + Literal(int32_t(curr->value.geti64() & 0xffffffff)) + ); + SetLocal* setHigh = builder->makeSetLocal( + highBits, + builder->makeConst( + Literal(int32_t(uint64_t(curr->value.geti64()) >> 32)) + ) + ); + Block* result = builder->blockify(setHigh, lowVal); + setOutParam(result, highBits); + replaceCurrent(result); + } + + void lowerEqZInt64(Unary* curr) { + Index highBits = fetchOutParam(curr->value); + replaceCurrent( + builder->makeBinary( + AndInt32, + builder->makeUnary(EqZInt32, builder->makeGetLocal(highBits, i32)), + builder->makeUnary(EqZInt32, curr->value) + ) + ); + freeTemp(highBits); + } + + void lowerExtendUInt32(Unary* curr) { + Index highBits = getTemp(); + Block* result = builder->blockify( + builder->makeSetLocal(highBits, builder->makeConst(Literal(int32_t(0)))), + curr->value + ); + setOutParam(result, highBits); + replaceCurrent(result); + } + + void lowerWrapInt64(Unary* curr) { + freeTemp(fetchOutParam(curr->value)); + replaceCurrent(curr->value); + } + + bool unaryNeedsLowering(UnaryOp op) { + switch (op) { + case ClzInt64: + case CtzInt64: + case PopcntInt64: + case EqZInt64: + case ExtendSInt32: + case ExtendUInt32: + case WrapInt64: + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt64: + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt64: + case ReinterpretFloat64: + case ConvertSInt64ToFloat32: + case ConvertSInt64ToFloat64: + case ConvertUInt64ToFloat32: + case ConvertUInt64ToFloat64: + case ReinterpretInt64: return true; + default: return false; + } + } + + void visitUnary(Unary* curr) { + if (!unaryNeedsLowering(curr->op)) return; + if (curr->type == unreachable || curr->value->type == unreachable) { + assert(!hasOutParam(curr->value)); + replaceCurrent(curr->value); + return; + } + assert(hasOutParam(curr->value) || curr->type == i64); + switch (curr->op) { + case ClzInt64: + case CtzInt64: + case PopcntInt64: goto err; + case EqZInt64: lowerEqZInt64(curr); break; + case ExtendSInt32: goto err; + case ExtendUInt32: lowerExtendUInt32(curr); break; + case WrapInt64: lowerWrapInt64(curr); break; + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt64: + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt64: + case ReinterpretFloat64: + case ConvertSInt64ToFloat32: + case ConvertSInt64ToFloat64: + case ConvertUInt64ToFloat32: + case ConvertUInt64ToFloat64: + case ReinterpretInt64: + err: default: + std::cerr << "Unhandled unary operator: " << curr->op << std::endl; + abort(); + } + } + + Block* lowerAdd(Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + SetLocal* addLow = builder->makeSetLocal( + leftHigh, + builder->makeBinary( + AddInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ) + ); + SetLocal* addHigh = builder->makeSetLocal( + rightHigh, + builder->makeBinary( + AddInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ); + SetLocal* carryBit = builder->makeSetLocal( + rightHigh, + builder->makeBinary( + AddInt32, + builder->makeGetLocal(rightHigh, i32), + builder->makeConst(Literal(int32_t(1))) + ) + ); + If* checkOverflow = builder->makeIf( + builder->makeBinary( + LtUInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ), + carryBit + ); + GetLocal* getLow = builder->makeGetLocal(leftHigh, i32); + result = builder->blockify(result, addLow, addHigh, checkOverflow, getLow); + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + setOutParam(result, rightHigh); + return result; + } + + Block* lowerMul(Block* result, Index leftLow, Index leftHigh, Index rightLow, + Index rightHigh) { + // high bits = ll*rh + lh*rl + ll1*rl1 + (ll0*rl1)>>16 + (ll1*rl0)>>16 + // low bits = ll*rl + Index leftLow0 = getTemp(); + Index leftLow1 = getTemp(); + Index rightLow0 = getTemp(); + Index rightLow1 = getTemp(); + SetLocal* setLL0 = builder->makeSetLocal( + leftLow0, + builder->makeBinary( + AndInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeConst(Literal(int32_t(0xffff))) + ) + ); + SetLocal* setLL1 = builder->makeSetLocal( + leftLow1, + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeConst(Literal(int32_t(16))) + ) + ); + SetLocal* setRL0 = builder->makeSetLocal( + rightLow0, + builder->makeBinary( + AndInt32, + builder->makeGetLocal(rightLow, i32), + builder->makeConst(Literal(int32_t(0xffff))) + ) + ); + SetLocal* setRL1 = builder->makeSetLocal( + rightLow1, + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(rightLow, i32), + builder->makeConst(Literal(int32_t(16))) + ) + ); + SetLocal* setLLRH = builder->makeSetLocal( + rightHigh, + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ); + auto addToHighBits = [&](Expression* expr) -> SetLocal* { + return builder->makeSetLocal( + rightHigh, + builder->makeBinary( + AddInt32, + builder->makeGetLocal(rightHigh, i32), + expr + ) + ); + }; + SetLocal* addLHRL = addToHighBits( + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightLow, i32) + ) + ); + SetLocal* addLL1RL1 = addToHighBits( + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow1, i32), + builder->makeGetLocal(rightLow1, i32) + ) + ); + SetLocal* addLL0RL1 = addToHighBits( + builder->makeBinary( + ShrUInt32, + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow0, i32), + builder->makeGetLocal(rightLow1, i32) + ), + builder->makeConst(Literal(int32_t(16))) + ) + ); + SetLocal* addLL1RL0 = addToHighBits( + builder->makeBinary( + ShrUInt32, + builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow1, i32), + builder->makeGetLocal(rightLow0, i32) + ), + builder->makeConst(Literal(int32_t(16))) + ) + ); + Binary* getLow = builder->makeBinary( + MulInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ); + result = builder->blockify( + result, + setLL0, + setLL1, + setRL0, + setRL1, + setLLRH, + addLHRL, + addLL1RL1, + addLL0RL1, + addLL1RL0, + getLow + ); + freeTemp(leftLow0); + freeTemp(leftLow1); + freeTemp(rightLow0); + freeTemp(rightLow1); + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + setOutParam(result, rightHigh); + return result; + } + + Block* lowerBitwise(BinaryOp op, Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + BinaryOp op32; + switch (op) { + case AndInt64: op32 = AndInt32; break; + case OrInt64: op32 = OrInt32; break; + case XorInt64: op32 = XorInt32; break; + default: abort(); + } + result = builder->blockify( + result, + builder->makeSetLocal( + rightHigh, + builder->makeBinary( + op32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ), + builder->makeBinary( + op32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ) + ); + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + setOutParam(result, rightHigh); + return result; + } + + Block* makeLargeShl(Index highBits, Index leftLow, Index shift) { + return builder->blockify( + builder->makeSetLocal( + highBits, + builder->makeBinary( + ShlInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(shift, i32) + ) + ), + builder->makeConst(Literal(int32_t(0))) + ); + } + + Block* makeLargeShrU(Index highBits, Index leftHigh, Index shift) { + return builder->blockify( + builder->makeSetLocal(highBits, builder->makeConst(Literal(int32_t(0)))), + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(shift, i32) + ) + ); + } + + Block* makeSmallShl(Index highBits, Index leftLow, Index leftHigh, + Index shift, Binary* shiftMask, Binary* widthLessShift) { + Binary* shiftedInBits = builder->makeBinary( + AndInt32, + shiftMask, + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftLow, i32), + widthLessShift + ) + ); + Binary* shiftHigh = builder->makeBinary( + ShlInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(shift, i32) + ); + return builder->blockify( + builder->makeSetLocal( + highBits, + builder->makeBinary(OrInt32, shiftedInBits, shiftHigh) + ), + builder->makeBinary( + ShlInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(shift, i32) + ) + ); + } + + Block* makeSmallShrU(Index highBits, Index leftLow, Index leftHigh, + Index shift, Binary* shiftMask, Binary* widthLessShift) { + Binary* shiftedInBits = builder->makeBinary( + ShlInt32, + builder->makeBinary( + AndInt32, + shiftMask, + builder->makeGetLocal(leftHigh, i32) + ), + widthLessShift + ); + Binary* shiftLow = builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(shift, i32) + ); + return builder->blockify( + builder->makeSetLocal( + highBits, + builder->makeBinary( + ShrUInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(shift, i32) + ) + ), + builder->makeBinary(OrInt32, shiftedInBits, shiftLow) + ); + } + + Block* lowerShU(BinaryOp op, Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + assert(op == ShlInt64 || op == ShrUInt64); + // shift left lowered as: + // if 32 <= rightLow % 64: + // high = leftLow << k; low = 0 + // else: + // high = (((1 << k) - 1) & (leftLow >> (32 - k))) | (leftHigh << k); + // low = leftLow << k + // where k = shift % 32. shift right is similar. + Index shift = getTemp(); + SetLocal* setShift = builder->makeSetLocal( + shift, + builder->makeBinary( + AndInt32, + builder->makeGetLocal(rightLow, i32), + builder->makeConst(Literal(int32_t(32 - 1))) + ) + ); + Binary* isLargeShift = builder->makeBinary( + LeUInt32, + builder->makeConst(Literal(int32_t(32))), + builder->makeBinary( + AndInt32, + builder->makeGetLocal(rightLow, i32), + builder->makeConst(Literal(int32_t(64 - 1))) + ) + ); + Block* largeShiftBlock; + switch (op) { + case ShlInt64: + largeShiftBlock = makeLargeShl(rightHigh, leftLow, shift); break; + case ShrUInt64: + largeShiftBlock = makeLargeShrU(rightHigh, leftHigh, shift); break; + default: abort(); + } + Binary* shiftMask = builder->makeBinary( + SubInt32, + builder->makeBinary( + ShlInt32, + builder->makeConst(Literal(int32_t(1))), + builder->makeGetLocal(shift, i32) + ), + builder->makeConst(Literal(int32_t(1))) + ); + Binary* widthLessShift = builder->makeBinary( + SubInt32, + builder->makeConst(Literal(int32_t(32))), + builder->makeGetLocal(shift, i32) + ); + Block* smallShiftBlock; + switch(op) { + case ShlInt64: { + smallShiftBlock = makeSmallShl(rightHigh, leftLow, leftHigh, + shift, shiftMask, widthLessShift); + break; + } + case ShrUInt64: { + smallShiftBlock = makeSmallShrU(rightHigh, leftLow, leftHigh, + shift, shiftMask, widthLessShift); + break; + } + default: abort(); + } + If* ifLargeShift = builder->makeIf( + isLargeShift, + largeShiftBlock, + smallShiftBlock + ); + result = builder->blockify(result, setShift, ifLargeShift); + freeTemp(shift); + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + setOutParam(result, rightHigh); + return result; + } + + Block* lowerEq(Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + freeTemp(rightHigh); + return builder->blockify( + result, + builder->makeBinary( + AndInt32, + builder->makeBinary( + EqInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ), + builder->makeBinary( + EqInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ) + ); + } + + Block* lowerNe(Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + freeTemp(rightHigh); + return builder->blockify( + result, + builder->makeBinary( + OrInt32, + builder->makeBinary( + NeInt32, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ), + builder->makeBinary( + NeInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ) + ) + ); + } + + Block* lowerUComp(BinaryOp op, Block* result, Index leftLow, Index leftHigh, + Index rightLow, Index rightHigh) { + freeTemp(leftLow); + freeTemp(leftHigh); + freeTemp(rightLow); + freeTemp(rightHigh); + BinaryOp highOp, lowOp; + switch (op) { + case LtUInt64: highOp = LtUInt32; lowOp = LtUInt32; break; + case LeUInt64: highOp = LtUInt32; lowOp = LeUInt32; break; + case GtUInt64: highOp = GtUInt32; lowOp = GtUInt32; break; + case GeUInt64: highOp = GtUInt32; lowOp = GeUInt32; break; + default: abort(); + } + Binary* compHigh = builder->makeBinary( + highOp, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ); + Binary* eqHigh = builder->makeBinary( + EqInt32, + builder->makeGetLocal(leftHigh, i32), + builder->makeGetLocal(rightHigh, i32) + ); + Binary* compLow = builder->makeBinary( + lowOp, + builder->makeGetLocal(leftLow, i32), + builder->makeGetLocal(rightLow, i32) + ); + return builder->blockify( + result, + builder->makeBinary( + OrInt32, + compHigh, + builder->makeBinary(AndInt32, eqHigh, compLow) + ) + ); + } + + bool binaryNeedsLowering(BinaryOp op) { + switch (op) { + case AddInt64: + case SubInt64: + case MulInt64: + case DivSInt64: + case DivUInt64: + case RemSInt64: + case RemUInt64: + case AndInt64: + case OrInt64: + case XorInt64: + case ShlInt64: + case ShrUInt64: + case ShrSInt64: + case RotLInt64: + case RotRInt64: + case EqInt64: + case NeInt64: + case LtSInt64: + case LtUInt64: + case LeSInt64: + case LeUInt64: + case GtSInt64: + case GtUInt64: + case GeSInt64: + case GeUInt64: return true; + default: return false; + } + } + + void visitBinary(Binary* curr) { + if (!binaryNeedsLowering(curr->op)) return; + if (!hasOutParam(curr->left)) { + // left unreachable, replace self with left + replaceCurrent(curr->left); + if (hasOutParam(curr->right)) { + freeTemp(fetchOutParam(curr->right)); + } + return; + } + if (!hasOutParam(curr->right)) { + // right unreachable, replace self with left then right + replaceCurrent( + builder->blockify(builder->makeDrop(curr->left), curr->right) + ); + freeTemp(fetchOutParam(curr->left)); + return; + } + // left and right reachable, lower normally + Index leftLow = getTemp(); + Index leftHigh = fetchOutParam(curr->left); + Index rightLow = getTemp(); + Index rightHigh = fetchOutParam(curr->right); + SetLocal* setRight = builder->makeSetLocal(rightLow, curr->right); + SetLocal* setLeft = builder->makeSetLocal(leftLow, curr->left); + Block* result = builder->blockify(setLeft, setRight); + switch (curr->op) { + case AddInt64: { + replaceCurrent(lowerAdd(result, leftLow, leftHigh, rightLow, + rightHigh)); + break; + } + case SubInt64: goto err; + case MulInt64: { + replaceCurrent(lowerMul(result, leftLow, leftHigh, rightLow, + rightHigh)); + break; + } + case DivSInt64: + case DivUInt64: + case RemSInt64: + case RemUInt64: goto err; + case AndInt64: + case OrInt64: + case XorInt64: { + replaceCurrent(lowerBitwise(curr->op, result, leftLow, leftHigh, + rightLow, rightHigh)); + break; + } + case ShlInt64: + case ShrUInt64: { + replaceCurrent(lowerShU(curr->op, result, leftLow, leftHigh, rightLow, + rightHigh)); + break; + } + case ShrSInt64: + case RotLInt64: + case RotRInt64: goto err; + case EqInt64: { + replaceCurrent(lowerEq(result, leftLow, leftHigh, rightLow, rightHigh)); + break; + } + case NeInt64: { + replaceCurrent(lowerNe(result, leftLow, leftHigh, rightLow, rightHigh)); + break; + } + case LtSInt64: + case LeSInt64: + case GtSInt64: + case GeSInt64: goto err; + case LtUInt64: + case LeUInt64: + case GtUInt64: + case GeUInt64: { + replaceCurrent(lowerUComp(curr->op, result, leftLow, leftHigh, rightLow, + rightHigh)); + break; + } + err: default: { + std::cerr << "Unhandled binary op " << curr->op << std::endl; + abort(); + } + } + } + + void visitSelect(Select* curr) { + visitBranching<Select>(curr); + } + + void visitDrop(Drop* curr) { + if (!hasOutParam(curr->value)) return; + freeTemp(fetchOutParam(curr->value)); + } + + void visitReturn(Return* curr) { + if (!hasOutParam(curr->value)) return; + Index lowBits = getTemp(); + Index highBits = fetchOutParam(curr->value); + SetLocal* setLow = builder->makeSetLocal(lowBits, curr->value); + SetGlobal* setHigh = builder->makeSetGlobal( + highBitsGlobal, + builder->makeGetLocal(highBits, i32) + ); + curr->value = builder->makeGetLocal(lowBits, i32); + Block* result = builder->blockify(setLow, setHigh, curr); + replaceCurrent(result); + freeTemp(lowBits); + freeTemp(highBits); + } + +private: + std::unique_ptr<Builder> builder; + std::unordered_map<Index, Index> indexMap; + std::unordered_map<Expression*, Index> returnIndices; + std::unordered_map<Name, Index> labelIndices; + std::vector<Index> freeTemps; + Index nextTemp; + + // TODO: RAII for temp var allocation + Index getTemp() { + Index ret; + if (freeTemps.size() > 0) { + ret = freeTemps.back(); + freeTemps.pop_back(); + } else { + ret = nextTemp++; + } + return ret; + } + + void freeTemp(Index t) { + assert(std::find(freeTemps.begin(), freeTemps.end(), t) == freeTemps.end()); + freeTemps.push_back(t); + } + + bool hasOutParam(Expression* e) { + return returnIndices.find(e) != returnIndices.end(); + } + + void setOutParam(Expression* e, Index idx) { + returnIndices[e] = idx; + } + + Index fetchOutParam(Expression* e) { + assert(returnIndices.find(e) != returnIndices.end()); + Index ret = returnIndices[e]; + returnIndices.erase(e); + return ret; + } +}; + +Name I64ToI32Lowering::highBitsGlobal("i64toi32_i32$HIGH_BITS"); + +Pass *createI64ToI32LoweringPass() { + return new I64ToI32Lowering(); +} + +} diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 4ffd74047..728fe2371 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -77,6 +77,7 @@ void PassRegistry::registerPasses() { registerPass("legalize-js-interface", "legalizes i64 types on the import/export boundary", createLegalizeJSInterfacePass); registerPass("local-cse", "common subexpression elimination inside basic blocks", createLocalCSEPass); registerPass("log-execution", "instrument the build with logging of where execution goes", createLogExecutionPass); + registerPass("i64-to-i32-lowering", "lower all uses of i64s to use i32s instead", createI64ToI32LoweringPass); registerPass("instrument-locals", "instrument the build with code to intercept all loads and stores", createInstrumentLocalsPass); registerPass("instrument-memory", "instrument the build with code to intercept all loads and stores", createInstrumentMemoryPass); registerPass("memory-packing", "packs memory into separate segments, skipping zeros", createMemoryPackingPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index 18c92b2cb..b1a6cb5c6 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -31,6 +31,7 @@ Pass *createDuplicateFunctionEliminationPass(); Pass *createExtractFunctionPass(); Pass *createFlattenControlFlowPass(); Pass *createFullPrinterPass(); +Pass *createI64ToI32LoweringPass(); Pass *createInliningPass(); Pass *createInliningOptimizingPass(); Pass *createLegalizeJSInterfacePass(); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 7432562d1..5acac65ee 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -311,6 +311,7 @@ public: static Index addParam(Function* func, Name name, WasmType type) { // only ok to add a param if no vars, otherwise indices are invalidated assert(func->localIndices.size() == func->params.size()); + assert(name.is()); func->params.push_back(type); Index index = func->localNames.size(); func->localIndices[name] = index; @@ -321,12 +322,8 @@ public: static Index addVar(Function* func, Name name, WasmType type) { // always ok to add a var, it does not affect other indices Index index = func->getNumLocals(); - if (name.is()) { - // if there is a name, apply it, but here we assume all the rest have names too FIXME - assert(func->localIndices.size() == func->params.size() + func->vars.size()); - func->localIndices[name] = index; - func->localNames.push_back(name); - } + if (name.is()) func->localIndices[name] = index; + func->localNames.push_back(name); func->vars.emplace_back(type); return index; } diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 6e46b8cb9..44bd6ed3d 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -275,6 +275,7 @@ struct Walker : public VisitorType { }; void pushTask(TaskFunc func, Expression** currp) { + assert(*currp); stack.emplace_back(func, currp); } void maybePushTask(TaskFunc func, Expression** currp) { diff --git a/src/wasm.h b/src/wasm.h index e9782ab6e..78aa1577b 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -614,7 +614,6 @@ public: Name getLocalNameOrDefault(Index index); Name getLocalNameOrGeneric(Index index); -private: bool hasLocalName(Index index) const; }; diff --git a/src/wasm2asm.h b/src/wasm2asm.h index ce9b99d0f..ef3ab0c0d 100644 --- a/src/wasm2asm.h +++ b/src/wasm2asm.h @@ -31,6 +31,8 @@ #include "emscripten-optimizer/optimizer.h" #include "mixed_arena.h" #include "asm_v_wasm.h" +#include "ast_utils.h" +#include "passes/passes.h" namespace wasm { @@ -148,6 +150,7 @@ public: } return ret; } + // Free a temp var. void freeTemp(WasmType type, IString temp) { frees[type].push_back(temp); @@ -155,18 +158,20 @@ public: static IString fromName(Name name) { // TODO: more clever name fixing, including checking we do not collide - const char *str = name.str; + const char* str = name.str; // check the various issues, and recurse so we check the others if (strchr(str, '-')) { - char *mod = strdup(str); // XXX leak + char* mod = strdup(str); str = mod; while (*mod) { if (*mod == '-') *mod = '_'; mod++; } - return fromName(IString(str, false)); + IString result = fromName(IString(str, false)); + free((void*)str); + return result; } - if (isdigit(str[0])) { + if (isdigit(str[0]) || strcmp(str, "if") == 0) { std::string prefixed = "$$"; prefixed += name.str; return fromName(IString(prefixed.c_str(), false)); @@ -177,6 +182,7 @@ public: void setStatement(Expression* curr) { willBeStatement.insert(curr); } + bool isStatement(Expression* curr) { return curr && willBeStatement.find(curr) != willBeStatement.end(); } @@ -200,10 +206,10 @@ private: size_t tableSize; void addBasics(Ref ast); - void addImport(Ref ast, Import *import); - void addTables(Ref ast, Module *wasm); - void addExports(Ref ast, Module *wasm); - void addWasmCompatibilityFuncs(Module *wasm); + void addImport(Ref ast, Import* import); + void addTables(Ref ast, Module* wasm); + void addExports(Ref ast, Module* wasm); + void addWasmCompatibilityFuncs(Module* wasm); bool isAssertHandled(Element& e); Ref makeAssertReturnFunc(SExpressionWasmBuilder& sexpBuilder, Builder& wasmBuilder, @@ -377,6 +383,13 @@ void Wasm2AsmBuilder::addWasmCompatibilityFuncs(Module* wasm) { Ref Wasm2AsmBuilder::processWasm(Module* wasm) { addWasmCompatibilityFuncs(wasm); + PassRunner runner(wasm); + runner.add<AutoDrop>(); + runner.add("i64-to-i32-lowering"); + runner.add("flatten-control-flow"); + runner.add("vacuum"); + runner.setDebug(flags.debug); + runner.run(); Ref ret = ValueBuilder::makeToplevel(); Ref asmFunc = ValueBuilder::makeFunction(ASM_FUNC); ret[1]->push_back(asmFunc); @@ -457,7 +470,7 @@ void Wasm2AsmBuilder::addBasics(Ref ast) { addMath(MATH_CLZ32, CLZ32); } -void Wasm2AsmBuilder::addImport(Ref ast, Import *import) { +void Wasm2AsmBuilder::addImport(Ref ast, Import* import) { Ref theVar = ValueBuilder::makeVar(); ast->push_back(theVar); Ref module = ValueBuilder::makeName(ENV); // TODO: handle nested module imports @@ -470,7 +483,7 @@ void Wasm2AsmBuilder::addImport(Ref ast, Import *import) { ); } -void Wasm2AsmBuilder::addTables(Ref ast, Module *wasm) { +void Wasm2AsmBuilder::addTables(Ref ast, Module* wasm) { std::map<std::string, std::vector<IString>> tables; // asm.js tables, sig => contents of table for (Table::Segment& seg : wasm->table.segments) { for (size_t i = 0; i < seg.data.size(); i++) { @@ -505,16 +518,24 @@ void Wasm2AsmBuilder::addTables(Ref ast, Module *wasm) { } } -void Wasm2AsmBuilder::addExports(Ref ast, Module *wasm) { +void Wasm2AsmBuilder::addExports(Ref ast, Module* wasm) { Ref exports = ValueBuilder::makeObject(); for (auto& export_ : wasm->exports) { - ValueBuilder::appendToObject(exports, fromName(export_->name), ValueBuilder::makeName(fromName(export_->value))); + ValueBuilder::appendToObject( + exports, + fromName(export_->name), + ValueBuilder::makeName(fromName(export_->value)) + ); } ast->push_back(ValueBuilder::makeStatement(ValueBuilder::makeReturn(exports))); } Ref Wasm2AsmBuilder::processFunction(Function* func) { - if (flags.debug) std::cerr << " processFunction " << func->name << std::endl; + if (flags.debug) { + static int fns = 0; + std::cerr << "processFunction " << (fns++) << " " << func->name + << std::endl; + } Ref ret = ValueBuilder::makeFunction(fromName(func->name)); frees.clear(); frees.resize(std::max(i32, std::max(f32, f64)) + 1); @@ -529,8 +550,10 @@ Ref Wasm2AsmBuilder::processFunction(Function* func) { ValueBuilder::makeStatement( ValueBuilder::makeBinary( ValueBuilder::makeName(name), SET, - makeAsmCoercion(ValueBuilder::makeName(name), - wasmToAsmType(func->getLocalType(i))) + makeAsmCoercion( + ValueBuilder::makeName(name), + wasmToAsmType(func->getLocalType(i)) + ) ) ) ); @@ -539,27 +562,46 @@ Ref Wasm2AsmBuilder::processFunction(Function* func) { size_t theVarIndex = ret[3]->size(); ret[3]->push_back(theVar); // body + auto appendFinalReturn = [&] (Ref retVal) { + flattenAppend( + ret, + ValueBuilder::makeReturn( + makeAsmCoercion(retVal, wasmToAsmType(func->result)) + ) + ); + }; scanFunctionBody(func->body); - if (isStatement(func->body)) { + bool isBodyBlock = (func->body->_id == Expression::BlockId); + ExpressionList* stats = isBodyBlock ? + &static_cast<Block*>(func->body)->list : nullptr; + bool endsInReturn = + (isBodyBlock && ((*stats)[stats->size()-1]->_id == Expression::ReturnId)); + if (endsInReturn) { + // return already taken care of + flattenAppend(ret, processFunctionBody(func, NO_RESULT)); + } else if (isStatement(func->body)) { + // store result in variable then return it IString result = - func->result != none ? getTemp(func->result, func) : NO_RESULT; - flattenAppend(ret, ValueBuilder::makeStatement(processFunctionBody(func, result))); + func->result != none ? getTemp(func->result, func) : NO_RESULT; + flattenAppend(ret, processFunctionBody(func, result)); if (func->result != none) { - // do the actual return - ret[3]->push_back(ValueBuilder::makeStatement(ValueBuilder::makeReturn(makeAsmCoercion(ValueBuilder::makeName(result), wasmToAsmType(func->result))))); + appendFinalReturn(ValueBuilder::makeName(result)); freeTemp(func->result, result); } + } else if (func->result != none) { + // whole thing is an expression, just return body + appendFinalReturn(processFunctionBody(func, EXPRESSION_RESULT)); } else { - // whole thing is an expression, just do a return - if (func->result != none) { - ret[3]->push_back(ValueBuilder::makeStatement(ValueBuilder::makeReturn(makeAsmCoercion(processFunctionBody(func, EXPRESSION_RESULT), wasmToAsmType(func->result))))); - } else { - flattenAppend(ret, processFunctionBody(func, NO_RESULT)); - } + // func has no return + flattenAppend(ret, processFunctionBody(func, NO_RESULT)); } // vars, including new temp vars for (Index i = func->getVarIndexBase(); i < func->getNumLocals(); i++) { - ValueBuilder::appendToVar(theVar, fromName(func->getLocalNameOrGeneric(i)), makeAsmCoercedZero(wasmToAsmType(func->getLocalType(i)))); + ValueBuilder::appendToVar( + theVar, + fromName(func->getLocalNameOrGeneric(i)), + makeAsmCoercedZero(wasmToAsmType(func->getLocalType(i))) + ); } if (theVar[1]->size() == 0) { ret[3]->splice(theVarIndex, 1); @@ -581,22 +623,22 @@ void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) { // Visitors - void visitBlock(Block *curr) { + void visitBlock(Block* curr) { parent->setStatement(curr); } - void visitIf(If *curr) { + void visitIf(If* curr) { parent->setStatement(curr); } - void visitLoop(Loop *curr) { + void visitLoop(Loop* curr) { parent->setStatement(curr); } - void visitBreak(Break *curr) { + void visitBreak(Break* curr) { parent->setStatement(curr); } - void visitSwitch(Switch *curr) { + void visitSwitch(Switch* curr) { parent->setStatement(curr); } - void visitCall(Call *curr) { + void visitCall(Call* curr) { for (auto item : curr->operands) { if (parent->isStatement(item)) { parent->setStatement(curr); @@ -604,7 +646,7 @@ void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) { } } } - void visitCallImport(CallImport *curr) { + void visitCallImport(CallImport* curr) { for (auto item : curr->operands) { if (parent->isStatement(item)) { parent->setStatement(curr); @@ -612,7 +654,7 @@ void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) { } } } - void visitCallIndirect(CallIndirect *curr) { + void visitCallIndirect(CallIndirect* curr) { if (parent->isStatement(curr->target)) { parent->setStatement(curr); return; @@ -624,40 +666,40 @@ void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) { } } } - void visitSetLocal(SetLocal *curr) { + void visitSetLocal(SetLocal* curr) { if (parent->isStatement(curr->value)) { parent->setStatement(curr); } } - void visitLoad(Load *curr) { + void visitLoad(Load* curr) { if (parent->isStatement(curr->ptr)) { parent->setStatement(curr); } } - void visitStore(Store *curr) { + void visitStore(Store* curr) { if (parent->isStatement(curr->ptr) || parent->isStatement(curr->value)) { parent->setStatement(curr); } } - void visitUnary(Unary *curr) { + void visitUnary(Unary* curr) { if (parent->isStatement(curr->value)) { parent->setStatement(curr); } } - void visitBinary(Binary *curr) { + void visitBinary(Binary* curr) { if (parent->isStatement(curr->left) || parent->isStatement(curr->right)) { parent->setStatement(curr); } } - void visitSelect(Select *curr) { + void visitSelect(Select* curr) { if (parent->isStatement(curr->ifTrue) || parent->isStatement(curr->ifFalse) || parent->isStatement(curr->condition)) { parent->setStatement(curr); } } - void visitReturn(Return *curr) { - abort(); + void visitReturn(Return* curr) { + parent->setStatement(curr); } - void visitHost(Host *curr) { + void visitHost(Host* curr) { for (auto item : curr->operands) { if (parent->isStatement(item)) { parent->setStatement(curr); @@ -779,7 +821,7 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { // Visitors - Ref visitBlock(Block *curr) { + Ref visitBlock(Block* curr) { breakResults[curr->name] = result; Ref ret = ValueBuilder::makeBlock(); size_t size = curr->list.size(); @@ -795,7 +837,8 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { } return ret; } - Ref visitIf(If *curr) { + + Ref visitIf(If* curr) { IString temp; Ref condition = visitForExpression(curr->condition, i32, temp); Ref ifTrue = ValueBuilder::makeStatement(visitAndAssign(curr->ifTrue, result)); @@ -811,15 +854,17 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { condition[1]->push_back(ValueBuilder::makeIf(ValueBuilder::makeName(temp), ifTrue, ifFalse)); return condition; } - Ref visitLoop(Loop *curr) { + + Ref visitLoop(Loop* curr) { Name asmLabel = curr->name; continueLabels.insert(asmLabel); - Ref body = visit(curr->body, result); - flattenAppend(body, ValueBuilder::makeBreak(asmLabel)); + Ref body = blockify(visit(curr->body, result)); + flattenAppend(body, ValueBuilder::makeBreak(fromName(asmLabel))); Ref ret = ValueBuilder::makeDo(body, ValueBuilder::makeInt(1)); return ValueBuilder::makeLabel(fromName(asmLabel), ret); } - Ref visitBreak(Break *curr) { + + Ref visitBreak(Break* curr) { if (curr->condition) { // we need an equivalent to an if here, so use that code Break fakeBreak = *curr; @@ -843,8 +888,10 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { ret[1]->push_back(theBreak); return ret; } - Expression *defaultBody = nullptr; // default must be last in asm.js - Ref visitSwitch(Switch *curr) { + + Expression* defaultBody = nullptr; // default must be last in asm.js + + Ref visitSwitch(Switch* curr) { assert(!curr->value); Ref ret = ValueBuilder::makeBlock(); Ref condition; @@ -855,8 +902,9 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { } else { condition = visit(curr->condition, EXPRESSION_RESULT); } - Ref theSwitch = ValueBuilder::makeSwitch(condition); - ret[2][1]->push_back(theSwitch); + Ref theSwitch = + ValueBuilder::makeSwitch(makeAsmCoercion(condition, ASM_INT)); + ret[1]->push_back(theSwitch); for (size_t i = 0; i < curr->targets.size(); i++) { ValueBuilder::appendCaseToSwitch(theSwitch, ValueBuilder::makeNum(i)); ValueBuilder::appendCodeToSwitch(theSwitch, blockify(ValueBuilder::makeBreak(fromName(curr->targets[i]))), false); @@ -887,23 +935,33 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { return ret; } - Ref visitCall(Call *curr) { - Ref theCall = ValueBuilder::makeCall(fromName(curr->target)); + Ref visitGenericCall(Expression* curr, Name target, + ExpressionList& operands) { + Ref theCall = ValueBuilder::makeCall(fromName(target)); if (!isStatement(curr)) { - // none of our operands is a statement; go right ahead and create a simple expression - for (auto operand : curr->operands) { - theCall[2]->push_back(makeAsmCoercion(visit(operand, EXPRESSION_RESULT), wasmToAsmType(operand->type))); + // none of our operands is a statement; go right ahead and create a + // simple expression + for (auto operand : operands) { + theCall[2]->push_back( + makeAsmCoercion(visit(operand, EXPRESSION_RESULT), + wasmToAsmType(operand->type))); } return makeAsmCoercion(theCall, wasmToAsmType(curr->type)); } // we must statementize them all - return makeStatementizedCall(curr->operands, ValueBuilder::makeBlock(), theCall, result, curr->type); + return makeStatementizedCall(operands, ValueBuilder::makeBlock(), theCall, + result, curr->type); } - Ref visitCallImport(CallImport *curr) { - std::cerr << "visitCallImport not implemented yet" << std::endl; - abort(); + + Ref visitCall(Call* curr) { + return visitGenericCall(curr, curr->target, curr->operands); + } + + Ref visitCallImport(CallImport* curr) { + return visitGenericCall(curr, curr->target, curr->operands); } - Ref visitCallIndirect(CallIndirect *curr) { + + Ref visitCallIndirect(CallIndirect* curr) { std::string stable = std::string("FUNCTION_TABLE_") + curr->fullType.c_str(); IString table = IString(stable.c_str(), false); auto makeTableCall = [&](Ref target) { @@ -927,25 +985,49 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { Ref theCall = makeTableCall(temp.getAstName()); return makeStatementizedCall(curr->operands, ret, theCall, result, curr->type); } - Ref visitGetLocal(GetLocal *curr) { - return ValueBuilder::makeName(fromName(func->getLocalNameOrGeneric(curr->index))); - } - Ref visitSetLocal(SetLocal *curr) { + + Ref makeSetVar(Expression* curr, Expression* value, Name name) { if (!isStatement(curr)) { return ValueBuilder::makeBinary( - ValueBuilder::makeName(fromName(func->getLocalNameOrGeneric(curr->index))), - SET, visit(curr->value, EXPRESSION_RESULT)); + ValueBuilder::makeName(fromName(name)), SET, + visit(value, EXPRESSION_RESULT) + ); } - ScopedTemp temp(curr->type, parent, func, result); // if result was provided, our child can just assign there. otherwise, allocate a temp for it to assign to. - Ref ret = blockify(visit(curr->value, temp)); + // if result was provided, our child can just assign there. + // Otherwise, allocate a temp for it to assign to. + ScopedTemp temp(value->type, parent, func, result); + Ref ret = blockify(visit(value, temp)); // the output was assigned to result, so we can just assign it to our target - ret[1]->push_back(ValueBuilder::makeStatement( + ret[1]->push_back( + ValueBuilder::makeStatement( ValueBuilder::makeBinary( - ValueBuilder::makeName(fromName(func->getLocalNameOrGeneric(curr->index))), - SET, temp.getAstName()))); + ValueBuilder::makeName(fromName(name)), SET, + temp.getAstName() + ) + ) + ); return ret; } - Ref visitLoad(Load *curr) { + + Ref visitGetLocal(GetLocal* curr) { + return ValueBuilder::makeName( + fromName(func->getLocalNameOrGeneric(curr->index)) + ); + } + + Ref visitSetLocal(SetLocal* curr) { + return makeSetVar(curr, curr->value, func->getLocalNameOrGeneric(curr->index)); + } + + Ref visitGetGlobal(GetGlobal* curr) { + return ValueBuilder::makeName(fromName(curr->name)); + } + + Ref visitSetGlobal(SetGlobal* curr) { + return makeSetVar(curr, curr->value, curr->name); + } + + Ref visitLoad(Load* curr) { if (isStatement(curr)) { ScopedTemp temp(i32, parent, func); GetLocal fakeLocal(allocator); @@ -1037,7 +1119,8 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { } return makeAsmCoercion(ret, wasmToAsmType(curr->type)); } - Ref visitStore(Store *curr) { + + Ref visitStore(Store* curr) { if (isStatement(curr)) { ScopedTemp tempPtr(i32, parent, func); ScopedTemp tempValue(curr->valueType, parent, func); @@ -1085,7 +1168,7 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { shift.value = Literal(int32_t(8*i)); shift.type = i32; Binary shifted(allocator); - shifted.op = ShrUInt64; + shifted.op = ShrUInt32; shifted.left = &getValue; shifted.right = &shift; shifted.type = i32; @@ -1140,11 +1223,13 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { } return ValueBuilder::makeBinary(ret, SET, value); } - Ref visitDrop(Drop *curr) { + + Ref visitDrop(Drop* curr) { assert(!isStatement(curr)); return visitAndAssign(curr->value, result); } - Ref visitConst(Const *curr) { + + Ref visitConst(Const* curr) { switch (curr->type) { case i32: return ValueBuilder::makeInt(curr->value.geti32()); case f32: { @@ -1165,7 +1250,8 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { default: abort(); } } - Ref visitUnary(Unary *curr) { + + Ref visitUnary(Unary* curr) { if (isStatement(curr)) { ScopedTemp temp(curr->value->type, parent, func); GetLocal fakeLocal(allocator); @@ -1267,12 +1353,13 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { return ret; } default: { - std::cerr << "Unhandled type: " << curr << std::endl; + std::cerr << "Unhandled type in unary: " << curr << std::endl; abort(); } } } - Ref visitBinary(Binary *curr) { + + Ref visitBinary(Binary* curr) { if (isStatement(curr)) { ScopedTemp tempLeft(curr->left->type, parent, func); GetLocal fakeLocalLeft(allocator); @@ -1399,7 +1486,8 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { } return makeAsmCoercion(ret, wasmToAsmType(curr->type)); } - Ref visitSelect(Select *curr) { + + Ref visitSelect(Select* curr) { if (isStatement(curr)) { ScopedTemp tempIfTrue(curr->ifTrue->type, parent, func); GetLocal fakeLocalIfTrue(allocator); @@ -1434,21 +1522,35 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { ValueBuilder::makeBinary(tempIfTrue.getAstName(), SET, ifTrue), ValueBuilder::makeSeq( ValueBuilder::makeBinary(tempIfFalse.getAstName(), SET, ifFalse), - ValueBuilder::makeConditional(tempCondition.getAstName(), tempIfTrue.getAstName(), tempIfFalse.getAstName()) + ValueBuilder::makeConditional( + tempCondition.getAstName(), + tempIfTrue.getAstName(), + tempIfFalse.getAstName() + ) ) ) ); } - Ref visitReturn(Return *curr) { - abort(); + + Ref visitReturn(Return* curr) { + Ref val = (curr->value == nullptr) ? + Ref() : + makeAsmCoercion( + visit(curr->value, NO_RESULT), + wasmToAsmType(curr->value->type) + ); + return ValueBuilder::makeReturn(val); } - Ref visitHost(Host *curr) { + + Ref visitHost(Host* curr) { abort(); } - Ref visitNop(Nop *curr) { + + Ref visitNop(Nop* curr) { return ValueBuilder::makeToplevel(); } - Ref visitUnreachable(Unreachable *curr) { + + Ref visitUnreachable(Unreachable* curr) { return ValueBuilder::makeCall(ABORT_FUNC); } }; @@ -1497,27 +1599,43 @@ Ref Wasm2AsmBuilder::makeAssertReturnFunc(SExpressionWasmBuilder& sexpBuilder, Builder& wasmBuilder, Element& e, Name testFuncName) { Expression* actual = sexpBuilder.parseExpression(e[1]); - Expression* expected = sexpBuilder.parseExpression(e[2]); - WasmType resType = expected->type; - actual->type = resType; - BinaryOp eqOp; - switch (resType) { - case i32: eqOp = EqInt32; break; - case i64: eqOp = EqInt64; break; - case f32: eqOp = EqFloat32; break; - case f64: eqOp = EqFloat64; break; - default: { - std::cerr << "Unhandled type in assert: " << resType << std::endl; - abort(); + Expression* body = nullptr; + if (e.size() == 2) { + if (actual->type == none) { + body = wasmBuilder.blockify( + actual, + wasmBuilder.makeConst(Literal(uint32_t(1))) + ); + } else { + body = actual; + } + } else if (e.size() == 3) { + Expression* expected = sexpBuilder.parseExpression(e[2]); + WasmType resType = expected->type; + actual->type = resType; + BinaryOp eqOp; + switch (resType) { + case i32: eqOp = EqInt32; break; + case i64: eqOp = EqInt64; break; + case f32: eqOp = EqFloat32; break; + case f64: eqOp = EqFloat64; break; + default: { + std::cerr << "Unhandled type in assert: " << resType << std::endl; + abort(); + } } + body = wasmBuilder.makeBinary(eqOp, actual, expected); + } else { + assert(false && "Unexpected number of parameters in assert_return"); } - Binary* test = wasmBuilder.makeBinary(eqOp, actual, expected); std::unique_ptr<Function> testFunc( - wasmBuilder.makeFunction(testFuncName, - std::vector<NameType>{}, - i32, - std::vector<NameType>{}, - test) + wasmBuilder.makeFunction( + testFuncName, + std::vector<NameType>{}, + body->type, + std::vector<NameType>{}, + body + ) ); Ref jsFunc = processFunction(testFunc.get()); prefixCalls(jsFunc); |