diff options
Diffstat (limited to 'src/wasm2asm.h')
-rw-r--r-- | src/wasm2asm.h | 243 |
1 files changed, 217 insertions, 26 deletions
diff --git a/src/wasm2asm.h b/src/wasm2asm.h index 4ad81816c..e2bb2bca7 100644 --- a/src/wasm2asm.h +++ b/src/wasm2asm.h @@ -27,6 +27,7 @@ #include "asmjs/shared-constants.h" #include "wasm.h" +#include "wasm-builder.h" #include "emscripten-optimizer/optimizer.h" #include "mixed_arena.h" #include "asm_v_wasm.h" @@ -44,10 +45,10 @@ IString ASM_FUNC("asmFunc"), // Appends extra to block, flattening out if extra is a block as well void flattenAppend(Ref ast, Ref extra) { int index; - if (ast[0] == BLOCK) index = 1; - else if (ast[0] == DEFUN) index = 3; + if (ast->isArray() && ast[0] == BLOCK) index = 1; + else if (ast->isArray() && ast[0] == DEFUN) index = 3; else abort(); - if (extra[0] == BLOCK) { + if (extra->isArray() && extra[0] == BLOCK) { for (size_t i = 0; i < extra[1]->size(); i++) { ast[index]->push_back(extra[1][i]); } @@ -190,13 +191,174 @@ private: void addImport(Ref ast, Import *import); void addTables(Ref ast, Module *wasm); void addExports(Ref ast, Module *wasm); + void addWasmCompatibilityFuncs(Module *wasm); Wasm2AsmBuilder() = delete; Wasm2AsmBuilder(const Wasm2AsmBuilder &) = delete; - Wasm2AsmBuilder &operator=(const Wasm2AsmBuilder &) = delete; + Wasm2AsmBuilder &operator=(const Wasm2AsmBuilder&) = delete; }; +static Function* makeCtzFunc(MixedArena& allocator, UnaryOp op) { + assert(op == CtzInt32 || op == CtzInt64); + Builder b(allocator); + // if eqz(x) then 32 else (32 - clz(x ^ (x - 1))) + bool is32Bit = (op == CtzInt32); + Name funcName = is32Bit ? Name(CTZ32) : Name(CTZ64); + BinaryOp subOp = is32Bit ? SubInt32 : SubInt64; + BinaryOp xorOp = is32Bit ? XorInt32 : XorInt64; + UnaryOp clzOp = is32Bit ? ClzInt32 : ClzInt64; + UnaryOp eqzOp = is32Bit ? EqZInt32 : EqZInt64; + WasmType argType = is32Bit ? i32 : i64; + Binary* xorExp = b.makeBinary( + xorOp, + b.makeGetLocal(0, i32), + b.makeBinary( + subOp, + b.makeGetLocal(0, i32), + b.makeConst(is32Bit ? Literal(int32_t(1)) : Literal(int64_t(1))) + ) + ); + Binary* subExp = b.makeBinary( + subOp, + b.makeConst(is32Bit ? Literal(int32_t(32 - 1)) : Literal(int64_t(64 - 1))), + b.makeUnary(clzOp, xorExp) + ); + If* body = b.makeIf( + b.makeUnary( + eqzOp, + b.makeGetLocal(0, i32) + ), + b.makeConst(is32Bit ? Literal(int32_t(32)) : Literal(int64_t(64))), + subExp + ); + return b.makeFunction( + funcName, + std::vector<NameType>{NameType("x", argType)}, + argType, + std::vector<NameType>{}, + body + ); +} + +static Function* makePopcntFunc(MixedArena& allocator, UnaryOp op) { + assert(op == PopcntInt32 || op == PopcntInt64); + Builder b(allocator); + // popcnt implemented as: + // int c; for (c = 0; x != 0; c++) { x = x & (x - 1) }; return c + bool is32Bit = (op == PopcntInt32); + Name funcName = is32Bit ? Name(POPCNT32) : Name(POPCNT64); + BinaryOp addOp = is32Bit ? AddInt32 : AddInt64; + BinaryOp subOp = is32Bit ? SubInt32 : SubInt64; + BinaryOp andOp = is32Bit ? AndInt32 : AndInt64; + UnaryOp eqzOp = is32Bit ? EqZInt32 : EqZInt64; + WasmType argType = is32Bit ? i32 : i64; + Name loopName("l"); + Name blockName("b"); + Break* brIf = b.makeBreak( + blockName, + b.makeGetLocal(1, i32), + b.makeUnary( + eqzOp, + b.makeGetLocal(0, argType) + ) + ); + SetLocal* update = b.makeSetLocal( + 0, + b.makeBinary( + andOp, + b.makeGetLocal(0, argType), + b.makeBinary( + subOp, + b.makeGetLocal(0, argType), + b.makeConst(is32Bit ? Literal(int32_t(1)) : Literal(int64_t(1))) + ) + ) + ); + SetLocal* inc = b.makeSetLocal( + 1, + b.makeBinary( + addOp, + b.makeGetLocal(1, argType), + b.makeConst(Literal(1)) + ) + ); + Break* cont = b.makeBreak(loopName); + Loop* loop = b.makeLoop(loopName, b.blockify(brIf, update, inc, cont)); + Block* loopBlock = b.blockifyWithName(loop, blockName); + SetLocal* initCount = b.makeSetLocal(1, b.makeConst(Literal(0))); + return b.makeFunction( + funcName, + std::vector<NameType>{NameType("x", argType)}, + argType, + std::vector<NameType>{NameType("count", argType)}, + b.blockify(initCount, loopBlock) + ); +} + +Function* makeRotFunc(MixedArena& allocator, BinaryOp op) { + assert(op == RotLInt32 || op == RotRInt32 || + op == RotLInt64 || op == RotRInt64); + Builder b(allocator); + // left rotate is: + // (((((~0) >>> k) & x) << k) | ((((~0) << (w - k)) & x) >>> (w - k))) + // where k is shift modulo w. reverse shifts for right rotate + bool is32Bit = (op == RotLInt32 || op == RotRInt32); + bool isLRot = (op == RotLInt32 || op == RotLInt64); + static Name names[2][2] = {{Name(ROTR64), Name(ROTR32)}, + {Name(ROTL64), Name(ROTL32)}}; + static BinaryOp shifters[2][2] = {{ShrUInt64, ShrUInt32}, + {ShlInt64, ShlInt32}}; + Name funcName = names[isLRot][is32Bit]; + BinaryOp lshift = shifters[isLRot][is32Bit]; + BinaryOp rshift = shifters[!isLRot][is32Bit]; + BinaryOp orOp = is32Bit ? OrInt32 : OrInt64; + BinaryOp andOp = is32Bit ? AndInt32 : AndInt64; + BinaryOp subOp = is32Bit ? SubInt32 : SubInt64; + WasmType argType = is32Bit ? i32 : i64; + Literal widthMask = + is32Bit ? Literal(int32_t(32 - 1)) : Literal(int64_t(64 - 1)); + Literal width = + is32Bit ? Literal(int32_t(32)) : Literal(int64_t(64)); + auto shiftVal = [&]() { + return b.makeBinary( + andOp, + b.makeGetLocal(1, argType), + b.makeConst(widthMask) + ); + }; + auto widthSub = [&]() { + return b.makeBinary(subOp, b.makeConst(width), shiftVal()); + }; + auto fullMask = [&]() { + return b.makeConst(is32Bit ? Literal(~int32_t(0)) : Literal(~int64_t(0))); + }; + Binary* maskRShift = b.makeBinary(rshift, fullMask(), shiftVal()); + Binary* lowMask = b.makeBinary(andOp, maskRShift, b.makeGetLocal(0, argType)); + Binary* lowShift = b.makeBinary(lshift, lowMask, shiftVal()); + Binary* maskLShift = b.makeBinary(lshift, fullMask(), widthSub()); + Binary* highMask = + b.makeBinary(andOp, maskLShift, b.makeGetLocal(0, argType)); + Binary* highShift = b.makeBinary(rshift, highMask, widthSub()); + Binary* body = b.makeBinary(orOp, lowShift, highShift); + return b.makeFunction( + funcName, + std::vector<NameType>{NameType("x", argType), + NameType("k", argType)}, + argType, + std::vector<NameType>{}, + body + ); +} + +void Wasm2AsmBuilder::addWasmCompatibilityFuncs(Module* wasm) { + wasm->addFunction(makeCtzFunc(wasm->allocator, CtzInt32)); + wasm->addFunction(makePopcntFunc(wasm->allocator, PopcntInt32)); + wasm->addFunction(makeRotFunc(wasm->allocator, RotLInt32)); + wasm->addFunction(makeRotFunc(wasm->allocator, RotRInt32)); +} + Ref Wasm2AsmBuilder::processWasm(Module* wasm) { + addWasmCompatibilityFuncs(wasm); Ref ret = ValueBuilder::makeToplevel(); Ref asmFunc = ValueBuilder::makeFunction(ASM_FUNC); ret[1]->push_back(asmFunc); @@ -642,7 +804,8 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { Name asmLabel = curr->name; continueLabels.insert(asmLabel); Ref body = visit(curr->body, result); - Ref ret = ValueBuilder::makeDo(body, ValueBuilder::makeInt(0)); + flattenAppend(body, ValueBuilder::makeBreak(asmLabel)); + Ref ret = ValueBuilder::makeDo(body, ValueBuilder::makeInt(1)); return ValueBuilder::makeLabel(fromName(asmLabel), ret); } Ref visitBreak(Break *curr) { @@ -995,15 +1158,28 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { return ret; } // normal unary - Ref value = visit(curr->value, EXPRESSION_RESULT); switch (curr->type) { case i32: { switch (curr->op) { case ClzInt32: - return ValueBuilder::makeCall(MATH_CLZ32, value); + return ValueBuilder::makeCall(MATH_CLZ32, + visit(curr->value, + EXPRESSION_RESULT)); + case CtzInt32: + return ValueBuilder::makeCall(CTZ32, visit(curr->value, + EXPRESSION_RESULT)); case PopcntInt32: - return ValueBuilder::makeCall(MATH_POPCNT32, value); - default: abort(); + return ValueBuilder::makeCall(POPCNT32, visit(curr->value, + EXPRESSION_RESULT)); + case EqZInt32: + return ValueBuilder::makeBinary( + makeAsmCoercion(visit(curr->value, + EXPRESSION_RESULT), ASM_INT), EQ, + makeAsmCoercion(ValueBuilder::makeInt(0), ASM_INT)); + default: + std::cerr << "Unhandled unary i32 operator: " << curr + << std::endl; + abort(); } } case f32: @@ -1012,41 +1188,52 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { switch (curr->op) { case NegFloat32: case NegFloat64: - ret = ValueBuilder::makeUnary(MINUS, visit(curr->value, - EXPRESSION_RESULT)); + ret = ValueBuilder::makeUnary( + MINUS, + visit(curr->value, EXPRESSION_RESULT) + ); break; case AbsFloat32: case AbsFloat64: - ret = ValueBuilder::makeCall(MATH_ABS, visit(curr->value, - EXPRESSION_RESULT)); + ret = ValueBuilder::makeCall( + MATH_ABS, + visit(curr->value, EXPRESSION_RESULT) + ); break; case CeilFloat32: case CeilFloat64: - ret = ValueBuilder::makeCall(MATH_CEIL, visit(curr->value, - EXPRESSION_RESULT)); + ret = ValueBuilder::makeCall( + MATH_CEIL, + visit(curr->value, EXPRESSION_RESULT) + ); break; case FloorFloat32: case FloorFloat64: - ret = ValueBuilder::makeCall(MATH_FLOOR, - visit(curr->value, - EXPRESSION_RESULT)); + ret = ValueBuilder::makeCall( + MATH_FLOOR, + visit(curr->value, EXPRESSION_RESULT) + ); break; case TruncFloat32: case TruncFloat64: - ret = ValueBuilder::makeCall(MATH_TRUNC, - visit(curr->value, - EXPRESSION_RESULT)); + ret = ValueBuilder::makeCall( + MATH_TRUNC, + visit(curr->value, EXPRESSION_RESULT) + ); break; case NearestFloat32: case NearestFloat64: - ret = ValueBuilder::makeCall(MATH_NEAREST, - visit(curr->value, - EXPRESSION_RESULT)); + ret = ValueBuilder::makeCall( + MATH_NEAREST, + visit(curr->value,EXPRESSION_RESULT) + ); break; case SqrtFloat32: case SqrtFloat64: - ret = ValueBuilder::makeCall(MATH_SQRT, visit(curr->value, - EXPRESSION_RESULT)); + ret = ValueBuilder::makeCall( + MATH_SQRT, + visit(curr->value, EXPRESSION_RESULT) + ); break; // TODO: more complex unary conversions default: @@ -1180,6 +1367,10 @@ Ref Wasm2AsmBuilder::processFunctionBody(Function* func, IString result) { case GeUInt32: return ValueBuilder::makeBinary(makeSigning(left, ASM_UNSIGNED), GE, makeSigning(right, ASM_UNSIGNED)); + case RotLInt32: + return ValueBuilder::makeCall(ROTL32, left, right); + case RotRInt32: + return ValueBuilder::makeCall(ROTR32, left, right); default: std::cerr << "Unhandled binary operator: " << curr << std::endl; abort(); |