diff options
Diffstat (limited to 'src/passes/RemoveNonJSOps.cpp')
-rw-r--r-- | src/passes/RemoveNonJSOps.cpp | 406 |
1 files changed, 90 insertions, 316 deletions
diff --git a/src/passes/RemoveNonJSOps.cpp b/src/passes/RemoveNonJSOps.cpp index 77809d6b3..4bd40a6c6 100644 --- a/src/passes/RemoveNonJSOps.cpp +++ b/src/passes/RemoveNonJSOps.cpp @@ -15,341 +15,97 @@ */ // -// // Removes all operations in a wasm module that aren't inherently implementable -// in JS. This includes things like `f32.nearest` and -// `f64.copysign`. Most operations are lowered to a call to an injected +// in JS. This includes things like 64-bit division, `f32.nearest`, +// `f64.copysign`, etc. Most operations are lowered to a call to an injected // intrinsic implementation. Intrinsics don't use themselves to implement // themselves. // +// You'll find a large wast blob in `wasm-intrinsics.wast` next to this file +// which contains all of the injected intrinsics. We manually copy over any +// needed intrinsics from this module into the module that we're optimizing +// after walking the current module. +// #include <wasm.h> #include <pass.h> #include "asmjs/shared-constants.h" #include "wasm-builder.h" +#include "wasm-s-parser.h" +#include "ir/module-utils.h" +#include "ir/find_all.h" +#include "passes/intrinsics-module.h" namespace wasm { struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> { - bool needNearestF32 = false; - bool needNearestF64 = false; - bool needTruncF32 = false; - bool needTruncF64 = false; - bool needCtzInt32 = false; - bool needPopcntInt32 = false; - bool needRotLInt32 = false; - bool needRotRInt32 = false; + std::unique_ptr<Builder> builder; + std::unordered_set<Name> neededIntrinsics; bool isFunctionParallel() override { return false; } Pass* create() override { return new RemoveNonJSOpsPass; } void doWalkModule(Module* module) { + // Discover all of the intrinsics that we need to inject, lowering all + // operations to intrinsic calls while we're at it. if (!builder) builder = make_unique<Builder>(*module); PostWalker<RemoveNonJSOpsPass>::doWalkModule(module); - if (needNearestF32) { - module->addFunction(createNearest(f32)); - } - if (needNearestF64) { - module->addFunction(createNearest(f64)); - } - if (needTruncF32) { - module->addFunction(createTrunc(f32)); - } - if (needTruncF64) { - module->addFunction(createTrunc(f64)); - } - if (needCtzInt32) { - module->addFunction(createCtz()); - } - if (needPopcntInt32) { - module->addFunction(createPopcnt()); + if (neededIntrinsics.size() == 0) { + return; } - if (needRotLInt32) { - module->addFunction(createRot(RotLInt32)); - } - if (needRotRInt32) { - module->addFunction(createRot(RotRInt32)); - } - } - - Function *createNearest(Type f) { - // fn nearest(f: float) -> float { - // let ceil = ceil(f); - // let floor = floor(f); - // let fract = f - floor; - // if fract < 0.5 { - // floor - // } else if fract > 0.5 { - // ceil - // } else { - // let rem = floor / 2.0; - // if rem - floor(rem) == 0.0 { - // floor - // } else { - // ceil - // } - // } - // } - Index arg = 0; - Index ceil = 1; - Index floor = 2; - Index fract = 3; - Index rem = 4; - - UnaryOp ceilOp = CeilFloat32; - UnaryOp floorOp = FloorFloat32; - BinaryOp subOp = SubFloat32; - BinaryOp ltOp = LtFloat32; - BinaryOp gtOp = GtFloat32; - BinaryOp divOp = DivFloat32; - BinaryOp eqOp = EqFloat32; - Literal litHalf((float) 0.5); - Literal litOne((float) 1.0); - Literal litZero((float) 0.0); - Literal litTwo((float) 2.0); - if (f == f64) { - ceilOp = CeilFloat64; - floorOp = FloorFloat64; - subOp = SubFloat64; - ltOp = LtFloat64; - gtOp = GtFloat64; - divOp = DivFloat64; - eqOp = EqFloat64; - litHalf = Literal((double) 0.5); - litOne = Literal((double) 1.0); - litZero = Literal((double) 0.0); - litTwo = Literal((double) 2.0); + // Parse the wast blob we have at the end of this file. + // + // TODO: only do this once per invocation of wasm2asm + Module intrinsicsModule; + std::string input(IntrinsicsModuleWast); + SExpressionParser parser(const_cast<char*>(input.c_str())); + Element& root = *parser.root; + SExpressionWasmBuilder builder(intrinsicsModule, *root[0]); + + std::set<Name> neededFunctions; + + // Iteratively link intrinsics from `intrinsicsModule` into our destination + // module, as needed. + // + // Note that intrinsics often use one another. For example the 64-bit + // division intrinsic ends up using the 32-bit ctz intrinsic, but does so + // via a native instruction. The loop here is used to continuously reprocess + // injected intrinsics to ensure that they never contain non-js ops when + // we're done. + while (neededIntrinsics.size() > 0) { + // Recursively probe all needed intrinsics for transitively used + // functions. This is building up a set of functions we'll link into our + // module. + for (auto &name : neededIntrinsics) { + addNeededFunctions(intrinsicsModule, name, neededFunctions); + } + neededIntrinsics.clear(); + + // Link in everything that wasn't already linked in. After we've done the + // copy we then walk the function to rewrite any non-js operations it has + // as well. + for (auto &name : neededFunctions) { + doWalkFunction(ModuleUtils::copyFunction(intrinsicsModule, *module, name)); + } + neededFunctions.clear(); } - - Expression *body = builder->blockify( - builder->makeSetLocal( - ceil, - builder->makeUnary(ceilOp, builder->makeGetLocal(arg, f)) - ), - builder->makeSetLocal( - floor, - builder->makeUnary(floorOp, builder->makeGetLocal(arg, f)) - ), - builder->makeSetLocal( - fract, - builder->makeBinary( - subOp, - builder->makeGetLocal(arg, f), - builder->makeGetLocal(floor, f) - ) - ), - builder->makeIf( - builder->makeBinary( - ltOp, - builder->makeGetLocal(fract, f), - builder->makeConst(litHalf) - ), - builder->makeGetLocal(floor, f), - builder->makeIf( - builder->makeBinary( - gtOp, - builder->makeGetLocal(fract, f), - builder->makeConst(litHalf) - ), - builder->makeGetLocal(ceil, f), - builder->blockify( - builder->makeSetLocal( - rem, - builder->makeBinary( - divOp, - builder->makeGetLocal(floor, f), - builder->makeConst(litTwo) - ) - ), - builder->makeIf( - builder->makeBinary( - eqOp, - builder->makeBinary( - subOp, - builder->makeGetLocal(rem, f), - builder->makeUnary( - floorOp, - builder->makeGetLocal(rem, f) - ) - ), - builder->makeConst(litZero) - ), - builder->makeGetLocal(floor, f), - builder->makeGetLocal(ceil, f) - ) - ) - ) - ) - ); - std::vector<Type> params = {f}; - std::vector<Type> vars = {f, f, f, f, f}; - Name name = f == f32 ? WASM_NEAREST_F32 : WASM_NEAREST_F64; - return builder->makeFunction(name, std::move(params), f, std::move(vars), body); } - Function *createTrunc(Type f) { - // fn trunc(f: float) -> float { - // if f < 0.0 { - // ceil(f) - // } else { - // floor(f) - // } - // } - - Index arg = 0; - - UnaryOp ceilOp = CeilFloat32; - UnaryOp floorOp = FloorFloat32; - BinaryOp ltOp = LtFloat32; - Literal litZero((float) 0.0); - if (f == f64) { - ceilOp = CeilFloat64; - floorOp = FloorFloat64; - ltOp = LtFloat64; - litZero = Literal((double) 0.0); + void addNeededFunctions(Module &m, Name name, std::set<Name> &needed) { + if (needed.count(name)) { + return; } + needed.insert(name); - Expression *body = builder->makeIf( - builder->makeBinary( - ltOp, - builder->makeGetLocal(arg, f), - builder->makeConst(litZero) - ), - builder->makeUnary(ceilOp, builder->makeGetLocal(arg, f)), - builder->makeUnary(floorOp, builder->makeGetLocal(arg, f)) - ); - std::vector<Type> params = {f}; - std::vector<Type> vars = {}; - Name name = f == f32 ? WASM_TRUNC_F32 : WASM_TRUNC_F64; - return builder->makeFunction(name, std::move(params), f, std::move(vars), body); - } - - Function* createCtz() { - // if eqz(x) then 32 else (32 - clz(x ^ (x - 1))) - Binary* xorExp = builder->makeBinary( - XorInt32, - builder->makeGetLocal(0, i32), - builder->makeBinary( - SubInt32, - builder->makeGetLocal(0, i32), - builder->makeConst(Literal(int32_t(1))) - ) - ); - Binary* subExp = builder->makeBinary( - SubInt32, - builder->makeConst(Literal(int32_t(32 - 1))), - builder->makeUnary(ClzInt32, xorExp) - ); - If* body = builder->makeIf( - builder->makeUnary( - EqZInt32, - builder->makeGetLocal(0, i32) - ), - builder->makeConst(Literal(int32_t(32))), - subExp - ); - return builder->makeFunction( - WASM_CTZ32, - std::vector<NameType>{NameType("x", i32)}, - i32, - std::vector<NameType>{}, - body - ); - } - - Function* createPopcnt() { - // popcnt implemented as: - // int c; for (c = 0; x != 0; c++) { x = x & (x - 1) }; return c - Name loopName("l"); - Name blockName("b"); - Break* brIf = builder->makeBreak( - blockName, - builder->makeGetLocal(1, i32), - builder->makeUnary( - EqZInt32, - builder->makeGetLocal(0, i32) - ) - ); - SetLocal* update = builder->makeSetLocal( - 0, - builder->makeBinary( - AndInt32, - builder->makeGetLocal(0, i32), - builder->makeBinary( - SubInt32, - builder->makeGetLocal(0, i32), - builder->makeConst(Literal(int32_t(1))) - ) - ) - ); - SetLocal* inc = builder->makeSetLocal( - 1, - builder->makeBinary( - AddInt32, - builder->makeGetLocal(1, i32), - builder->makeConst(Literal(1)) - ) - ); - Break* cont = builder->makeBreak(loopName); - Loop* loop = builder->makeLoop( - loopName, - builder->blockify(builder->makeDrop(brIf), update, inc, cont) - ); - Block* loopBlock = builder->blockifyWithName(loop, blockName); - // TODO: not sure why this is necessary... - loopBlock->type = i32; - SetLocal* initCount = builder->makeSetLocal(1, builder->makeConst(Literal(0))); - return builder->makeFunction( - WASM_POPCNT32, - std::vector<NameType>{NameType("x", i32)}, - i32, - std::vector<NameType>{NameType("count", i32)}, - builder->blockify(initCount, loopBlock) - ); - } - - Function* createRot(BinaryOp op) { - // left rotate is: - // (((((~0) >>> k) & x) << k) | ((((~0) << (w - k)) & x) >>> (w - k))) - // where k is shift modulo w. reverse shifts for right rotate - bool isLRot = op == RotLInt32; - BinaryOp lshift = isLRot ? ShlInt32 : ShrUInt32; - BinaryOp rshift = isLRot ? ShrUInt32 : ShlInt32; - Literal widthMask(int32_t(32 - 1)); - Literal width(int32_t(32)); - auto shiftVal = [&]() { - return builder->makeBinary( - AndInt32, - builder->makeGetLocal(1, i32), - builder->makeConst(widthMask) - ); - }; - auto widthSub = [&]() { - return builder->makeBinary(SubInt32, builder->makeConst(width), shiftVal()); - }; - auto fullMask = [&]() { - return builder->makeConst(Literal(~int32_t(0))); - }; - Binary* maskRShift = builder->makeBinary(rshift, fullMask(), shiftVal()); - Binary* lowMask = builder->makeBinary(AndInt32, maskRShift, builder->makeGetLocal(0, i32)); - Binary* lowShift = builder->makeBinary(lshift, lowMask, shiftVal()); - Binary* maskLShift = builder->makeBinary(lshift, fullMask(), widthSub()); - Binary* highMask = - builder->makeBinary(AndInt32, maskLShift, builder->makeGetLocal(0, i32)); - Binary* highShift = builder->makeBinary(rshift, highMask, widthSub()); - Binary* body = builder->makeBinary(OrInt32, lowShift, highShift); - return builder->makeFunction( - isLRot ? WASM_ROTL32 : WASM_ROTR32, - std::vector<NameType>{NameType("x", i32), - NameType("k", i32)}, - i32, - std::vector<NameType>{}, - body - ); + auto function = m.getFunction(name); + FindAll<Call> calls(function->body); + for (auto &call : calls.list) { + this->addNeededFunctions(m, call->target, needed); + } } void doWalkFunction(Function* func) { @@ -366,16 +122,36 @@ struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> { return; case RotLInt32: - needRotLInt32 = true; name = WASM_ROTL32; break; case RotRInt32: - needRotRInt32 = true; name = WASM_ROTR32; break; + case RotLInt64: + name = WASM_ROTL64; + break; + case RotRInt64: + name = WASM_ROTR64; + break; + case MulInt64: + name = WASM_I64_MUL; + break; + case DivSInt64: + name = WASM_I64_SDIV; + break; + case DivUInt64: + name = WASM_I64_UDIV; + break; + case RemSInt64: + name = WASM_I64_SREM; + break; + case RemUInt64: + name = WASM_I64_UREM; + break; default: return; } + neededIntrinsics.insert(name); replaceCurrent(builder->makeCall(name, {curr->left, curr->right}, curr->type)); } @@ -435,40 +211,38 @@ struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> { Name functionCall; switch (curr->op) { case NearestFloat32: - needNearestF32 = true; functionCall = WASM_NEAREST_F32; break; case NearestFloat64: - needNearestF64 = true; functionCall = WASM_NEAREST_F64; break; case TruncFloat32: - needTruncF32 = true; functionCall = WASM_TRUNC_F32; break; case TruncFloat64: - needTruncF64 = true; functionCall = WASM_TRUNC_F64; break; + case PopcntInt64: + functionCall = WASM_POPCNT64; + break; case PopcntInt32: - needPopcntInt32 = true; functionCall = WASM_POPCNT32; break; + case CtzInt64: + functionCall = WASM_CTZ64; + break; case CtzInt32: - needCtzInt32 = true; functionCall = WASM_CTZ32; break; default: return; } + neededIntrinsics.insert(functionCall); replaceCurrent(builder->makeCall(functionCall, {curr->value}, curr->type)); } - -private: - std::unique_ptr<Builder> builder; }; Pass *createRemoveNonJSOpsPass() { |