summaryrefslogtreecommitdiff
path: root/src/passes/RemoveNonJSOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/RemoveNonJSOps.cpp')
-rw-r--r--src/passes/RemoveNonJSOps.cpp406
1 files changed, 90 insertions, 316 deletions
diff --git a/src/passes/RemoveNonJSOps.cpp b/src/passes/RemoveNonJSOps.cpp
index 77809d6b3..4bd40a6c6 100644
--- a/src/passes/RemoveNonJSOps.cpp
+++ b/src/passes/RemoveNonJSOps.cpp
@@ -15,341 +15,97 @@
*/
//
-//
// Removes all operations in a wasm module that aren't inherently implementable
-// in JS. This includes things like `f32.nearest` and
-// `f64.copysign`. Most operations are lowered to a call to an injected
+// in JS. This includes things like 64-bit division, `f32.nearest`,
+// `f64.copysign`, etc. Most operations are lowered to a call to an injected
// intrinsic implementation. Intrinsics don't use themselves to implement
// themselves.
//
+// You'll find a large wast blob in `wasm-intrinsics.wast` next to this file
+// which contains all of the injected intrinsics. We manually copy over any
+// needed intrinsics from this module into the module that we're optimizing
+// after walking the current module.
+//
#include <wasm.h>
#include <pass.h>
#include "asmjs/shared-constants.h"
#include "wasm-builder.h"
+#include "wasm-s-parser.h"
+#include "ir/module-utils.h"
+#include "ir/find_all.h"
+#include "passes/intrinsics-module.h"
namespace wasm {
struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> {
- bool needNearestF32 = false;
- bool needNearestF64 = false;
- bool needTruncF32 = false;
- bool needTruncF64 = false;
- bool needCtzInt32 = false;
- bool needPopcntInt32 = false;
- bool needRotLInt32 = false;
- bool needRotRInt32 = false;
+ std::unique_ptr<Builder> builder;
+ std::unordered_set<Name> neededIntrinsics;
bool isFunctionParallel() override { return false; }
Pass* create() override { return new RemoveNonJSOpsPass; }
void doWalkModule(Module* module) {
+ // Discover all of the intrinsics that we need to inject, lowering all
+ // operations to intrinsic calls while we're at it.
if (!builder) builder = make_unique<Builder>(*module);
PostWalker<RemoveNonJSOpsPass>::doWalkModule(module);
- if (needNearestF32) {
- module->addFunction(createNearest(f32));
- }
- if (needNearestF64) {
- module->addFunction(createNearest(f64));
- }
- if (needTruncF32) {
- module->addFunction(createTrunc(f32));
- }
- if (needTruncF64) {
- module->addFunction(createTrunc(f64));
- }
- if (needCtzInt32) {
- module->addFunction(createCtz());
- }
- if (needPopcntInt32) {
- module->addFunction(createPopcnt());
+ if (neededIntrinsics.size() == 0) {
+ return;
}
- if (needRotLInt32) {
- module->addFunction(createRot(RotLInt32));
- }
- if (needRotRInt32) {
- module->addFunction(createRot(RotRInt32));
- }
- }
-
- Function *createNearest(Type f) {
- // fn nearest(f: float) -> float {
- // let ceil = ceil(f);
- // let floor = floor(f);
- // let fract = f - floor;
- // if fract < 0.5 {
- // floor
- // } else if fract > 0.5 {
- // ceil
- // } else {
- // let rem = floor / 2.0;
- // if rem - floor(rem) == 0.0 {
- // floor
- // } else {
- // ceil
- // }
- // }
- // }
- Index arg = 0;
- Index ceil = 1;
- Index floor = 2;
- Index fract = 3;
- Index rem = 4;
-
- UnaryOp ceilOp = CeilFloat32;
- UnaryOp floorOp = FloorFloat32;
- BinaryOp subOp = SubFloat32;
- BinaryOp ltOp = LtFloat32;
- BinaryOp gtOp = GtFloat32;
- BinaryOp divOp = DivFloat32;
- BinaryOp eqOp = EqFloat32;
- Literal litHalf((float) 0.5);
- Literal litOne((float) 1.0);
- Literal litZero((float) 0.0);
- Literal litTwo((float) 2.0);
- if (f == f64) {
- ceilOp = CeilFloat64;
- floorOp = FloorFloat64;
- subOp = SubFloat64;
- ltOp = LtFloat64;
- gtOp = GtFloat64;
- divOp = DivFloat64;
- eqOp = EqFloat64;
- litHalf = Literal((double) 0.5);
- litOne = Literal((double) 1.0);
- litZero = Literal((double) 0.0);
- litTwo = Literal((double) 2.0);
+ // Parse the wast blob we have at the end of this file.
+ //
+ // TODO: only do this once per invocation of wasm2asm
+ Module intrinsicsModule;
+ std::string input(IntrinsicsModuleWast);
+ SExpressionParser parser(const_cast<char*>(input.c_str()));
+ Element& root = *parser.root;
+ SExpressionWasmBuilder builder(intrinsicsModule, *root[0]);
+
+ std::set<Name> neededFunctions;
+
+ // Iteratively link intrinsics from `intrinsicsModule` into our destination
+ // module, as needed.
+ //
+ // Note that intrinsics often use one another. For example the 64-bit
+ // division intrinsic ends up using the 32-bit ctz intrinsic, but does so
+ // via a native instruction. The loop here is used to continuously reprocess
+ // injected intrinsics to ensure that they never contain non-js ops when
+ // we're done.
+ while (neededIntrinsics.size() > 0) {
+ // Recursively probe all needed intrinsics for transitively used
+ // functions. This is building up a set of functions we'll link into our
+ // module.
+ for (auto &name : neededIntrinsics) {
+ addNeededFunctions(intrinsicsModule, name, neededFunctions);
+ }
+ neededIntrinsics.clear();
+
+ // Link in everything that wasn't already linked in. After we've done the
+ // copy we then walk the function to rewrite any non-js operations it has
+ // as well.
+ for (auto &name : neededFunctions) {
+ doWalkFunction(ModuleUtils::copyFunction(intrinsicsModule, *module, name));
+ }
+ neededFunctions.clear();
}
-
- Expression *body = builder->blockify(
- builder->makeSetLocal(
- ceil,
- builder->makeUnary(ceilOp, builder->makeGetLocal(arg, f))
- ),
- builder->makeSetLocal(
- floor,
- builder->makeUnary(floorOp, builder->makeGetLocal(arg, f))
- ),
- builder->makeSetLocal(
- fract,
- builder->makeBinary(
- subOp,
- builder->makeGetLocal(arg, f),
- builder->makeGetLocal(floor, f)
- )
- ),
- builder->makeIf(
- builder->makeBinary(
- ltOp,
- builder->makeGetLocal(fract, f),
- builder->makeConst(litHalf)
- ),
- builder->makeGetLocal(floor, f),
- builder->makeIf(
- builder->makeBinary(
- gtOp,
- builder->makeGetLocal(fract, f),
- builder->makeConst(litHalf)
- ),
- builder->makeGetLocal(ceil, f),
- builder->blockify(
- builder->makeSetLocal(
- rem,
- builder->makeBinary(
- divOp,
- builder->makeGetLocal(floor, f),
- builder->makeConst(litTwo)
- )
- ),
- builder->makeIf(
- builder->makeBinary(
- eqOp,
- builder->makeBinary(
- subOp,
- builder->makeGetLocal(rem, f),
- builder->makeUnary(
- floorOp,
- builder->makeGetLocal(rem, f)
- )
- ),
- builder->makeConst(litZero)
- ),
- builder->makeGetLocal(floor, f),
- builder->makeGetLocal(ceil, f)
- )
- )
- )
- )
- );
- std::vector<Type> params = {f};
- std::vector<Type> vars = {f, f, f, f, f};
- Name name = f == f32 ? WASM_NEAREST_F32 : WASM_NEAREST_F64;
- return builder->makeFunction(name, std::move(params), f, std::move(vars), body);
}
- Function *createTrunc(Type f) {
- // fn trunc(f: float) -> float {
- // if f < 0.0 {
- // ceil(f)
- // } else {
- // floor(f)
- // }
- // }
-
- Index arg = 0;
-
- UnaryOp ceilOp = CeilFloat32;
- UnaryOp floorOp = FloorFloat32;
- BinaryOp ltOp = LtFloat32;
- Literal litZero((float) 0.0);
- if (f == f64) {
- ceilOp = CeilFloat64;
- floorOp = FloorFloat64;
- ltOp = LtFloat64;
- litZero = Literal((double) 0.0);
+ void addNeededFunctions(Module &m, Name name, std::set<Name> &needed) {
+ if (needed.count(name)) {
+ return;
}
+ needed.insert(name);
- Expression *body = builder->makeIf(
- builder->makeBinary(
- ltOp,
- builder->makeGetLocal(arg, f),
- builder->makeConst(litZero)
- ),
- builder->makeUnary(ceilOp, builder->makeGetLocal(arg, f)),
- builder->makeUnary(floorOp, builder->makeGetLocal(arg, f))
- );
- std::vector<Type> params = {f};
- std::vector<Type> vars = {};
- Name name = f == f32 ? WASM_TRUNC_F32 : WASM_TRUNC_F64;
- return builder->makeFunction(name, std::move(params), f, std::move(vars), body);
- }
-
- Function* createCtz() {
- // if eqz(x) then 32 else (32 - clz(x ^ (x - 1)))
- Binary* xorExp = builder->makeBinary(
- XorInt32,
- builder->makeGetLocal(0, i32),
- builder->makeBinary(
- SubInt32,
- builder->makeGetLocal(0, i32),
- builder->makeConst(Literal(int32_t(1)))
- )
- );
- Binary* subExp = builder->makeBinary(
- SubInt32,
- builder->makeConst(Literal(int32_t(32 - 1))),
- builder->makeUnary(ClzInt32, xorExp)
- );
- If* body = builder->makeIf(
- builder->makeUnary(
- EqZInt32,
- builder->makeGetLocal(0, i32)
- ),
- builder->makeConst(Literal(int32_t(32))),
- subExp
- );
- return builder->makeFunction(
- WASM_CTZ32,
- std::vector<NameType>{NameType("x", i32)},
- i32,
- std::vector<NameType>{},
- body
- );
- }
-
- Function* createPopcnt() {
- // popcnt implemented as:
- // int c; for (c = 0; x != 0; c++) { x = x & (x - 1) }; return c
- Name loopName("l");
- Name blockName("b");
- Break* brIf = builder->makeBreak(
- blockName,
- builder->makeGetLocal(1, i32),
- builder->makeUnary(
- EqZInt32,
- builder->makeGetLocal(0, i32)
- )
- );
- SetLocal* update = builder->makeSetLocal(
- 0,
- builder->makeBinary(
- AndInt32,
- builder->makeGetLocal(0, i32),
- builder->makeBinary(
- SubInt32,
- builder->makeGetLocal(0, i32),
- builder->makeConst(Literal(int32_t(1)))
- )
- )
- );
- SetLocal* inc = builder->makeSetLocal(
- 1,
- builder->makeBinary(
- AddInt32,
- builder->makeGetLocal(1, i32),
- builder->makeConst(Literal(1))
- )
- );
- Break* cont = builder->makeBreak(loopName);
- Loop* loop = builder->makeLoop(
- loopName,
- builder->blockify(builder->makeDrop(brIf), update, inc, cont)
- );
- Block* loopBlock = builder->blockifyWithName(loop, blockName);
- // TODO: not sure why this is necessary...
- loopBlock->type = i32;
- SetLocal* initCount = builder->makeSetLocal(1, builder->makeConst(Literal(0)));
- return builder->makeFunction(
- WASM_POPCNT32,
- std::vector<NameType>{NameType("x", i32)},
- i32,
- std::vector<NameType>{NameType("count", i32)},
- builder->blockify(initCount, loopBlock)
- );
- }
-
- Function* createRot(BinaryOp op) {
- // left rotate is:
- // (((((~0) >>> k) & x) << k) | ((((~0) << (w - k)) & x) >>> (w - k)))
- // where k is shift modulo w. reverse shifts for right rotate
- bool isLRot = op == RotLInt32;
- BinaryOp lshift = isLRot ? ShlInt32 : ShrUInt32;
- BinaryOp rshift = isLRot ? ShrUInt32 : ShlInt32;
- Literal widthMask(int32_t(32 - 1));
- Literal width(int32_t(32));
- auto shiftVal = [&]() {
- return builder->makeBinary(
- AndInt32,
- builder->makeGetLocal(1, i32),
- builder->makeConst(widthMask)
- );
- };
- auto widthSub = [&]() {
- return builder->makeBinary(SubInt32, builder->makeConst(width), shiftVal());
- };
- auto fullMask = [&]() {
- return builder->makeConst(Literal(~int32_t(0)));
- };
- Binary* maskRShift = builder->makeBinary(rshift, fullMask(), shiftVal());
- Binary* lowMask = builder->makeBinary(AndInt32, maskRShift, builder->makeGetLocal(0, i32));
- Binary* lowShift = builder->makeBinary(lshift, lowMask, shiftVal());
- Binary* maskLShift = builder->makeBinary(lshift, fullMask(), widthSub());
- Binary* highMask =
- builder->makeBinary(AndInt32, maskLShift, builder->makeGetLocal(0, i32));
- Binary* highShift = builder->makeBinary(rshift, highMask, widthSub());
- Binary* body = builder->makeBinary(OrInt32, lowShift, highShift);
- return builder->makeFunction(
- isLRot ? WASM_ROTL32 : WASM_ROTR32,
- std::vector<NameType>{NameType("x", i32),
- NameType("k", i32)},
- i32,
- std::vector<NameType>{},
- body
- );
+ auto function = m.getFunction(name);
+ FindAll<Call> calls(function->body);
+ for (auto &call : calls.list) {
+ this->addNeededFunctions(m, call->target, needed);
+ }
}
void doWalkFunction(Function* func) {
@@ -366,16 +122,36 @@ struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> {
return;
case RotLInt32:
- needRotLInt32 = true;
name = WASM_ROTL32;
break;
case RotRInt32:
- needRotRInt32 = true;
name = WASM_ROTR32;
break;
+ case RotLInt64:
+ name = WASM_ROTL64;
+ break;
+ case RotRInt64:
+ name = WASM_ROTR64;
+ break;
+ case MulInt64:
+ name = WASM_I64_MUL;
+ break;
+ case DivSInt64:
+ name = WASM_I64_SDIV;
+ break;
+ case DivUInt64:
+ name = WASM_I64_UDIV;
+ break;
+ case RemSInt64:
+ name = WASM_I64_SREM;
+ break;
+ case RemUInt64:
+ name = WASM_I64_UREM;
+ break;
default: return;
}
+ neededIntrinsics.insert(name);
replaceCurrent(builder->makeCall(name, {curr->left, curr->right}, curr->type));
}
@@ -435,40 +211,38 @@ struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> {
Name functionCall;
switch (curr->op) {
case NearestFloat32:
- needNearestF32 = true;
functionCall = WASM_NEAREST_F32;
break;
case NearestFloat64:
- needNearestF64 = true;
functionCall = WASM_NEAREST_F64;
break;
case TruncFloat32:
- needTruncF32 = true;
functionCall = WASM_TRUNC_F32;
break;
case TruncFloat64:
- needTruncF64 = true;
functionCall = WASM_TRUNC_F64;
break;
+ case PopcntInt64:
+ functionCall = WASM_POPCNT64;
+ break;
case PopcntInt32:
- needPopcntInt32 = true;
functionCall = WASM_POPCNT32;
break;
+ case CtzInt64:
+ functionCall = WASM_CTZ64;
+ break;
case CtzInt32:
- needCtzInt32 = true;
functionCall = WASM_CTZ32;
break;
default: return;
}
+ neededIntrinsics.insert(functionCall);
replaceCurrent(builder->makeCall(functionCall, {curr->value}, curr->type));
}
-
-private:
- std::unique_ptr<Builder> builder;
};
Pass *createRemoveNonJSOpsPass() {