summaryrefslogtreecommitdiff
path: root/src/passes/RemoveNonJSOps.cpp
diff options
context:
space:
mode:
authorAlex Crichton <alex@alexcrichton.com>2018-05-25 10:54:05 -0500
committerAlon Zakai <alonzakai@gmail.com>2018-05-25 08:54:05 -0700
commitbecfa3ee4214bb4e6fbe1dbdbf6c3756c548d82b (patch)
tree1108b2519d6ac1f97ae12b0d32792aa15a4e9237 /src/passes/RemoveNonJSOps.cpp
parent6e56ef195d01936c0c7f7a6d1c0f5d1c7e2c2e52 (diff)
downloadbinaryen-becfa3ee4214bb4e6fbe1dbdbf6c3756c548d82b.tar.gz
binaryen-becfa3ee4214bb4e6fbe1dbdbf6c3756c548d82b.tar.bz2
binaryen-becfa3ee4214bb4e6fbe1dbdbf6c3756c548d82b.zip
wasm2asm: Finish i64 lowering operations (#1563)
* wasm2asm: Finish i64 lowering operations This commit finishes out lowering i64 operations to JS with implementations of division and remainder for JS. The primary change here is to have these compiled from Rust to wasm and then have them "linked in" via intrinsics. The `RemoveNonJSOps` pass has been updated to include some of what `I64ToI32Lowering` was previously doing, basically replacing some instructions with calls to intrinsics. The intrinsics are now all tracked in one location. Hopefully the intrinsics don't need to be regenerated too much, but for posterity the source currently [lives in a gist][gist], although I suspect that gist won't continue to compile and work as-is for all of time. [gist]: https://gist.github.com/alexcrichton/e7ea67bcdd17ce4b6254e66f77165690
Diffstat (limited to 'src/passes/RemoveNonJSOps.cpp')
-rw-r--r--src/passes/RemoveNonJSOps.cpp406
1 files changed, 90 insertions, 316 deletions
diff --git a/src/passes/RemoveNonJSOps.cpp b/src/passes/RemoveNonJSOps.cpp
index 77809d6b3..4bd40a6c6 100644
--- a/src/passes/RemoveNonJSOps.cpp
+++ b/src/passes/RemoveNonJSOps.cpp
@@ -15,341 +15,97 @@
*/
//
-//
// Removes all operations in a wasm module that aren't inherently implementable
-// in JS. This includes things like `f32.nearest` and
-// `f64.copysign`. Most operations are lowered to a call to an injected
+// in JS. This includes things like 64-bit division, `f32.nearest`,
+// `f64.copysign`, etc. Most operations are lowered to a call to an injected
// intrinsic implementation. Intrinsics don't use themselves to implement
// themselves.
//
+// You'll find a large wast blob in `wasm-intrinsics.wast` next to this file
+// which contains all of the injected intrinsics. We manually copy over any
+// needed intrinsics from this module into the module that we're optimizing
+// after walking the current module.
+//
#include <wasm.h>
#include <pass.h>
#include "asmjs/shared-constants.h"
#include "wasm-builder.h"
+#include "wasm-s-parser.h"
+#include "ir/module-utils.h"
+#include "ir/find_all.h"
+#include "passes/intrinsics-module.h"
namespace wasm {
struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> {
- bool needNearestF32 = false;
- bool needNearestF64 = false;
- bool needTruncF32 = false;
- bool needTruncF64 = false;
- bool needCtzInt32 = false;
- bool needPopcntInt32 = false;
- bool needRotLInt32 = false;
- bool needRotRInt32 = false;
+ std::unique_ptr<Builder> builder;
+ std::unordered_set<Name> neededIntrinsics;
bool isFunctionParallel() override { return false; }
Pass* create() override { return new RemoveNonJSOpsPass; }
void doWalkModule(Module* module) {
+ // Discover all of the intrinsics that we need to inject, lowering all
+ // operations to intrinsic calls while we're at it.
if (!builder) builder = make_unique<Builder>(*module);
PostWalker<RemoveNonJSOpsPass>::doWalkModule(module);
- if (needNearestF32) {
- module->addFunction(createNearest(f32));
- }
- if (needNearestF64) {
- module->addFunction(createNearest(f64));
- }
- if (needTruncF32) {
- module->addFunction(createTrunc(f32));
- }
- if (needTruncF64) {
- module->addFunction(createTrunc(f64));
- }
- if (needCtzInt32) {
- module->addFunction(createCtz());
- }
- if (needPopcntInt32) {
- module->addFunction(createPopcnt());
+ if (neededIntrinsics.size() == 0) {
+ return;
}
- if (needRotLInt32) {
- module->addFunction(createRot(RotLInt32));
- }
- if (needRotRInt32) {
- module->addFunction(createRot(RotRInt32));
- }
- }
-
- Function *createNearest(Type f) {
- // fn nearest(f: float) -> float {
- // let ceil = ceil(f);
- // let floor = floor(f);
- // let fract = f - floor;
- // if fract < 0.5 {
- // floor
- // } else if fract > 0.5 {
- // ceil
- // } else {
- // let rem = floor / 2.0;
- // if rem - floor(rem) == 0.0 {
- // floor
- // } else {
- // ceil
- // }
- // }
- // }
- Index arg = 0;
- Index ceil = 1;
- Index floor = 2;
- Index fract = 3;
- Index rem = 4;
-
- UnaryOp ceilOp = CeilFloat32;
- UnaryOp floorOp = FloorFloat32;
- BinaryOp subOp = SubFloat32;
- BinaryOp ltOp = LtFloat32;
- BinaryOp gtOp = GtFloat32;
- BinaryOp divOp = DivFloat32;
- BinaryOp eqOp = EqFloat32;
- Literal litHalf((float) 0.5);
- Literal litOne((float) 1.0);
- Literal litZero((float) 0.0);
- Literal litTwo((float) 2.0);
- if (f == f64) {
- ceilOp = CeilFloat64;
- floorOp = FloorFloat64;
- subOp = SubFloat64;
- ltOp = LtFloat64;
- gtOp = GtFloat64;
- divOp = DivFloat64;
- eqOp = EqFloat64;
- litHalf = Literal((double) 0.5);
- litOne = Literal((double) 1.0);
- litZero = Literal((double) 0.0);
- litTwo = Literal((double) 2.0);
+ // Parse the wast blob we have at the end of this file.
+ //
+ // TODO: only do this once per invocation of wasm2asm
+ Module intrinsicsModule;
+ std::string input(IntrinsicsModuleWast);
+ SExpressionParser parser(const_cast<char*>(input.c_str()));
+ Element& root = *parser.root;
+ SExpressionWasmBuilder builder(intrinsicsModule, *root[0]);
+
+ std::set<Name> neededFunctions;
+
+ // Iteratively link intrinsics from `intrinsicsModule` into our destination
+ // module, as needed.
+ //
+ // Note that intrinsics often use one another. For example the 64-bit
+ // division intrinsic ends up using the 32-bit ctz intrinsic, but does so
+ // via a native instruction. The loop here is used to continuously reprocess
+ // injected intrinsics to ensure that they never contain non-js ops when
+ // we're done.
+ while (neededIntrinsics.size() > 0) {
+ // Recursively probe all needed intrinsics for transitively used
+ // functions. This is building up a set of functions we'll link into our
+ // module.
+ for (auto &name : neededIntrinsics) {
+ addNeededFunctions(intrinsicsModule, name, neededFunctions);
+ }
+ neededIntrinsics.clear();
+
+ // Link in everything that wasn't already linked in. After we've done the
+ // copy we then walk the function to rewrite any non-js operations it has
+ // as well.
+ for (auto &name : neededFunctions) {
+ doWalkFunction(ModuleUtils::copyFunction(intrinsicsModule, *module, name));
+ }
+ neededFunctions.clear();
}
-
- Expression *body = builder->blockify(
- builder->makeSetLocal(
- ceil,
- builder->makeUnary(ceilOp, builder->makeGetLocal(arg, f))
- ),
- builder->makeSetLocal(
- floor,
- builder->makeUnary(floorOp, builder->makeGetLocal(arg, f))
- ),
- builder->makeSetLocal(
- fract,
- builder->makeBinary(
- subOp,
- builder->makeGetLocal(arg, f),
- builder->makeGetLocal(floor, f)
- )
- ),
- builder->makeIf(
- builder->makeBinary(
- ltOp,
- builder->makeGetLocal(fract, f),
- builder->makeConst(litHalf)
- ),
- builder->makeGetLocal(floor, f),
- builder->makeIf(
- builder->makeBinary(
- gtOp,
- builder->makeGetLocal(fract, f),
- builder->makeConst(litHalf)
- ),
- builder->makeGetLocal(ceil, f),
- builder->blockify(
- builder->makeSetLocal(
- rem,
- builder->makeBinary(
- divOp,
- builder->makeGetLocal(floor, f),
- builder->makeConst(litTwo)
- )
- ),
- builder->makeIf(
- builder->makeBinary(
- eqOp,
- builder->makeBinary(
- subOp,
- builder->makeGetLocal(rem, f),
- builder->makeUnary(
- floorOp,
- builder->makeGetLocal(rem, f)
- )
- ),
- builder->makeConst(litZero)
- ),
- builder->makeGetLocal(floor, f),
- builder->makeGetLocal(ceil, f)
- )
- )
- )
- )
- );
- std::vector<Type> params = {f};
- std::vector<Type> vars = {f, f, f, f, f};
- Name name = f == f32 ? WASM_NEAREST_F32 : WASM_NEAREST_F64;
- return builder->makeFunction(name, std::move(params), f, std::move(vars), body);
}
- Function *createTrunc(Type f) {
- // fn trunc(f: float) -> float {
- // if f < 0.0 {
- // ceil(f)
- // } else {
- // floor(f)
- // }
- // }
-
- Index arg = 0;
-
- UnaryOp ceilOp = CeilFloat32;
- UnaryOp floorOp = FloorFloat32;
- BinaryOp ltOp = LtFloat32;
- Literal litZero((float) 0.0);
- if (f == f64) {
- ceilOp = CeilFloat64;
- floorOp = FloorFloat64;
- ltOp = LtFloat64;
- litZero = Literal((double) 0.0);
+ void addNeededFunctions(Module &m, Name name, std::set<Name> &needed) {
+ if (needed.count(name)) {
+ return;
}
+ needed.insert(name);
- Expression *body = builder->makeIf(
- builder->makeBinary(
- ltOp,
- builder->makeGetLocal(arg, f),
- builder->makeConst(litZero)
- ),
- builder->makeUnary(ceilOp, builder->makeGetLocal(arg, f)),
- builder->makeUnary(floorOp, builder->makeGetLocal(arg, f))
- );
- std::vector<Type> params = {f};
- std::vector<Type> vars = {};
- Name name = f == f32 ? WASM_TRUNC_F32 : WASM_TRUNC_F64;
- return builder->makeFunction(name, std::move(params), f, std::move(vars), body);
- }
-
- Function* createCtz() {
- // if eqz(x) then 32 else (32 - clz(x ^ (x - 1)))
- Binary* xorExp = builder->makeBinary(
- XorInt32,
- builder->makeGetLocal(0, i32),
- builder->makeBinary(
- SubInt32,
- builder->makeGetLocal(0, i32),
- builder->makeConst(Literal(int32_t(1)))
- )
- );
- Binary* subExp = builder->makeBinary(
- SubInt32,
- builder->makeConst(Literal(int32_t(32 - 1))),
- builder->makeUnary(ClzInt32, xorExp)
- );
- If* body = builder->makeIf(
- builder->makeUnary(
- EqZInt32,
- builder->makeGetLocal(0, i32)
- ),
- builder->makeConst(Literal(int32_t(32))),
- subExp
- );
- return builder->makeFunction(
- WASM_CTZ32,
- std::vector<NameType>{NameType("x", i32)},
- i32,
- std::vector<NameType>{},
- body
- );
- }
-
- Function* createPopcnt() {
- // popcnt implemented as:
- // int c; for (c = 0; x != 0; c++) { x = x & (x - 1) }; return c
- Name loopName("l");
- Name blockName("b");
- Break* brIf = builder->makeBreak(
- blockName,
- builder->makeGetLocal(1, i32),
- builder->makeUnary(
- EqZInt32,
- builder->makeGetLocal(0, i32)
- )
- );
- SetLocal* update = builder->makeSetLocal(
- 0,
- builder->makeBinary(
- AndInt32,
- builder->makeGetLocal(0, i32),
- builder->makeBinary(
- SubInt32,
- builder->makeGetLocal(0, i32),
- builder->makeConst(Literal(int32_t(1)))
- )
- )
- );
- SetLocal* inc = builder->makeSetLocal(
- 1,
- builder->makeBinary(
- AddInt32,
- builder->makeGetLocal(1, i32),
- builder->makeConst(Literal(1))
- )
- );
- Break* cont = builder->makeBreak(loopName);
- Loop* loop = builder->makeLoop(
- loopName,
- builder->blockify(builder->makeDrop(brIf), update, inc, cont)
- );
- Block* loopBlock = builder->blockifyWithName(loop, blockName);
- // TODO: not sure why this is necessary...
- loopBlock->type = i32;
- SetLocal* initCount = builder->makeSetLocal(1, builder->makeConst(Literal(0)));
- return builder->makeFunction(
- WASM_POPCNT32,
- std::vector<NameType>{NameType("x", i32)},
- i32,
- std::vector<NameType>{NameType("count", i32)},
- builder->blockify(initCount, loopBlock)
- );
- }
-
- Function* createRot(BinaryOp op) {
- // left rotate is:
- // (((((~0) >>> k) & x) << k) | ((((~0) << (w - k)) & x) >>> (w - k)))
- // where k is shift modulo w. reverse shifts for right rotate
- bool isLRot = op == RotLInt32;
- BinaryOp lshift = isLRot ? ShlInt32 : ShrUInt32;
- BinaryOp rshift = isLRot ? ShrUInt32 : ShlInt32;
- Literal widthMask(int32_t(32 - 1));
- Literal width(int32_t(32));
- auto shiftVal = [&]() {
- return builder->makeBinary(
- AndInt32,
- builder->makeGetLocal(1, i32),
- builder->makeConst(widthMask)
- );
- };
- auto widthSub = [&]() {
- return builder->makeBinary(SubInt32, builder->makeConst(width), shiftVal());
- };
- auto fullMask = [&]() {
- return builder->makeConst(Literal(~int32_t(0)));
- };
- Binary* maskRShift = builder->makeBinary(rshift, fullMask(), shiftVal());
- Binary* lowMask = builder->makeBinary(AndInt32, maskRShift, builder->makeGetLocal(0, i32));
- Binary* lowShift = builder->makeBinary(lshift, lowMask, shiftVal());
- Binary* maskLShift = builder->makeBinary(lshift, fullMask(), widthSub());
- Binary* highMask =
- builder->makeBinary(AndInt32, maskLShift, builder->makeGetLocal(0, i32));
- Binary* highShift = builder->makeBinary(rshift, highMask, widthSub());
- Binary* body = builder->makeBinary(OrInt32, lowShift, highShift);
- return builder->makeFunction(
- isLRot ? WASM_ROTL32 : WASM_ROTR32,
- std::vector<NameType>{NameType("x", i32),
- NameType("k", i32)},
- i32,
- std::vector<NameType>{},
- body
- );
+ auto function = m.getFunction(name);
+ FindAll<Call> calls(function->body);
+ for (auto &call : calls.list) {
+ this->addNeededFunctions(m, call->target, needed);
+ }
}
void doWalkFunction(Function* func) {
@@ -366,16 +122,36 @@ struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> {
return;
case RotLInt32:
- needRotLInt32 = true;
name = WASM_ROTL32;
break;
case RotRInt32:
- needRotRInt32 = true;
name = WASM_ROTR32;
break;
+ case RotLInt64:
+ name = WASM_ROTL64;
+ break;
+ case RotRInt64:
+ name = WASM_ROTR64;
+ break;
+ case MulInt64:
+ name = WASM_I64_MUL;
+ break;
+ case DivSInt64:
+ name = WASM_I64_SDIV;
+ break;
+ case DivUInt64:
+ name = WASM_I64_UDIV;
+ break;
+ case RemSInt64:
+ name = WASM_I64_SREM;
+ break;
+ case RemUInt64:
+ name = WASM_I64_UREM;
+ break;
default: return;
}
+ neededIntrinsics.insert(name);
replaceCurrent(builder->makeCall(name, {curr->left, curr->right}, curr->type));
}
@@ -435,40 +211,38 @@ struct RemoveNonJSOpsPass : public WalkerPass<PostWalker<RemoveNonJSOpsPass>> {
Name functionCall;
switch (curr->op) {
case NearestFloat32:
- needNearestF32 = true;
functionCall = WASM_NEAREST_F32;
break;
case NearestFloat64:
- needNearestF64 = true;
functionCall = WASM_NEAREST_F64;
break;
case TruncFloat32:
- needTruncF32 = true;
functionCall = WASM_TRUNC_F32;
break;
case TruncFloat64:
- needTruncF64 = true;
functionCall = WASM_TRUNC_F64;
break;
+ case PopcntInt64:
+ functionCall = WASM_POPCNT64;
+ break;
case PopcntInt32:
- needPopcntInt32 = true;
functionCall = WASM_POPCNT32;
break;
+ case CtzInt64:
+ functionCall = WASM_CTZ64;
+ break;
case CtzInt32:
- needCtzInt32 = true;
functionCall = WASM_CTZ32;
break;
default: return;
}
+ neededIntrinsics.insert(functionCall);
replaceCurrent(builder->makeCall(functionCall, {curr->value}, curr->type));
}
-
-private:
- std::unique_ptr<Builder> builder;
};
Pass *createRemoveNonJSOpsPass() {