summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlon Zakai <alonzakai@gmail.com>2018-04-13 09:17:21 -0700
committerGitHub <noreply@github.com>2018-04-13 09:17:21 -0700
commitfd3b3e54bd97abbf8269b33d937ad2f44ba4bb60 (patch)
treef36ce91f19944f00460035105758f61606a42bdd
parent7a8273ae2c1854b9840fc56a952e572f673bb10f (diff)
downloadbinaryen-fd3b3e54bd97abbf8269b33d937ad2f44ba4bb60.tar.gz
binaryen-fd3b3e54bd97abbf8269b33d937ad2f44ba4bb60.tar.bz2
binaryen-fd3b3e54bd97abbf8269b33d937ad2f44ba4bb60.zip
Refactor interpreter (#1508)
* Move more logic to the Literal class. We now leave all the work to there, except for handling traps. * Avoid switching on the type, then the opcode, then Literal method usually switches on the type again - instead, do one big switch for the opcodes (then the Literal method is unchanged) which is shorter and clearer, and avoids that first switching.
-rw-r--r--src/literal.h11
-rw-r--r--src/wasm-interpreter.h352
-rw-r--r--src/wasm/literal.cpp56
3 files changed, 220 insertions, 199 deletions
diff --git a/src/literal.h b/src/literal.h
index 72bd36019..ae628149e 100644
--- a/src/literal.h
+++ b/src/literal.h
@@ -38,12 +38,6 @@ private:
int64_t i64;
};
- // The RHS of shl/shru/shrs must be masked by bitwidth.
- template <typename T>
- static T shiftMask(T val) {
- return val & (sizeof(T) * 8 - 1);
- }
-
public:
Literal() : type(Type::none), i64(0) {}
explicit Literal(Type type) : type(type), i64(0) {}
@@ -98,6 +92,9 @@ public:
Literal extendToSI64() const;
Literal extendToUI64() const;
Literal extendToF64() const;
+ Literal extendS8() const;
+ Literal extendS16() const;
+ Literal extendS32() const;
Literal truncateToI32() const;
Literal truncateToF32() const;
@@ -106,6 +103,7 @@ public:
Literal convertSToF64() const;
Literal convertUToF64() const;
+ Literal eqz() const;
Literal neg() const;
Literal abs() const;
Literal ceil() const;
@@ -113,6 +111,7 @@ public:
Literal trunc() const;
Literal nearbyint() const;
Literal sqrt() const;
+ Literal demote() const;
Literal add(const Literal& other) const;
Literal sub(const Literal& other) const;
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index 73cfc393d..d1aab7ab6 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -32,7 +32,6 @@
#include "wasm.h"
#include "wasm-traversal.h"
-
#ifdef WASM_INTERPRETER_DEBUG
#include "wasm-printing.h"
#endif
@@ -231,98 +230,73 @@ public:
NOTE_EVAL1(curr->value);
return Flow(curr->value); // heh
}
+
+ // Unary and Binary nodes, the core math computations. We mostly just
+ // delegate to the Literal::* methods, except we handle traps here.
+
Flow visitUnary(Unary *curr) {
NOTE_ENTER("Unary");
Flow flow = visit(curr->value);
if (flow.breaking()) return flow;
Literal value = flow.value;
NOTE_EVAL1(value);
- if (value.type == i32) {
- switch (curr->op) {
- case ClzInt32: return value.countLeadingZeroes();
- case CtzInt32: return value.countTrailingZeroes();
- case PopcntInt32: return value.popCount();
- case EqZInt32: return Literal(int32_t(value == Literal(int32_t(0))));
- case ReinterpretInt32: return value.castToF32();
- case ExtendSInt32: return value.extendToSI64();
- case ExtendUInt32: return value.extendToUI64();
- case ConvertUInt32ToFloat32: return value.convertUToF32();
- case ConvertUInt32ToFloat64: return value.convertUToF64();
- case ConvertSInt32ToFloat32: return value.convertSToF32();
- case ConvertSInt32ToFloat64: return value.convertSToF64();
- case ExtendS8Int32: return Literal(int32_t(int8_t(value.geti32() & 0xFF)));
- case ExtendS16Int32: return Literal(int32_t(int16_t(value.geti32() & 0xFFFF)));
- default: WASM_UNREACHABLE();
- }
+ switch (curr->op) {
+ case ClzInt32:
+ case ClzInt64: return value.countLeadingZeroes();
+ case CtzInt32:
+ case CtzInt64: return value.countTrailingZeroes();
+ case PopcntInt32:
+ case PopcntInt64: return value.popCount();
+ case EqZInt32:
+ case EqZInt64: return value.eqz();
+ case ReinterpretInt32: return value.castToF32();
+ case ReinterpretInt64: return value.castToF64();
+ case ExtendSInt32: return value.extendToSI64();
+ case ExtendUInt32: return value.extendToUI64();
+ case WrapInt64: return value.truncateToI32();
+ case ConvertUInt32ToFloat32:
+ case ConvertUInt64ToFloat32: return value.convertUToF32();
+ case ConvertUInt32ToFloat64:
+ case ConvertUInt64ToFloat64: return value.convertUToF64();
+ case ConvertSInt32ToFloat32:
+ case ConvertSInt64ToFloat32: return value.convertSToF32();
+ case ConvertSInt32ToFloat64:
+ case ConvertSInt64ToFloat64: return value.convertSToF64();
+ case ExtendS8Int32:
+ case ExtendS8Int64: return value.extendS8();
+ case ExtendS16Int32:
+ case ExtendS16Int64: return value.extendS16();
+ case ExtendS32Int64: return value.extendS32();
+
+ case NegFloat32:
+ case NegFloat64: return value.neg();
+ case AbsFloat32:
+ case AbsFloat64: return value.abs();
+ case CeilFloat32:
+ case CeilFloat64: return value.ceil();
+ case FloorFloat32:
+ case FloorFloat64: return value.floor();
+ case TruncFloat32:
+ case TruncFloat64: return value.trunc();
+ case NearestFloat32:
+ case NearestFloat64: return value.nearbyint();
+ case SqrtFloat32:
+ case SqrtFloat64: return value.sqrt();
+ case TruncSFloat32ToInt32:
+ case TruncSFloat64ToInt32:
+ case TruncSFloat32ToInt64:
+ case TruncSFloat64ToInt64: return truncSFloat(curr, value);
+ case TruncUFloat32ToInt32:
+ case TruncUFloat64ToInt32:
+ case TruncUFloat32ToInt64:
+ case TruncUFloat64ToInt64: return truncUFloat(curr, value);
+ case ReinterpretFloat32: return value.castToI32();
+ case PromoteFloat32: return value.extendToF64();
+ case ReinterpretFloat64: return value.castToI64();
+ case DemoteFloat64: return value.demote();
+
+ default: WASM_UNREACHABLE();
}
- if (value.type == i64) {
- switch (curr->op) {
- case ClzInt64: return value.countLeadingZeroes();
- case CtzInt64: return value.countTrailingZeroes();
- case PopcntInt64: return value.popCount();
- case EqZInt64: return Literal(int32_t(value == Literal(int64_t(0))));
- case WrapInt64: return value.truncateToI32();
- case ReinterpretInt64: return value.castToF64();
- case ConvertUInt64ToFloat32: return value.convertUToF32();
- case ConvertUInt64ToFloat64: return value.convertUToF64();
- case ConvertSInt64ToFloat32: return value.convertSToF32();
- case ConvertSInt64ToFloat64: return value.convertSToF64();
- case ExtendS8Int64: return Literal(int64_t(int8_t(value.geti64() & 0xFF)));
- case ExtendS16Int64: return Literal(int64_t(int16_t(value.geti64() & 0xFFFF)));
- case ExtendS32Int64: return Literal(int64_t(int32_t(value.geti64() & 0xFFFFFFFF)));
- default: WASM_UNREACHABLE();
- }
- }
- if (value.type == f32) {
- switch (curr->op) {
- case NegFloat32: return value.neg();
- case AbsFloat32: return value.abs();
- case CeilFloat32: return value.ceil();
- case FloorFloat32: return value.floor();
- case TruncFloat32: return value.trunc();
- case NearestFloat32: return value.nearbyint();
- case SqrtFloat32: return value.sqrt();
- case TruncSFloat32ToInt32:
- case TruncSFloat32ToInt64: return truncSFloat(curr, value);
- case TruncUFloat32ToInt32:
- case TruncUFloat32ToInt64: return truncUFloat(curr, value);
- case ReinterpretFloat32: return value.castToI32();
- case PromoteFloat32: return value.extendToF64();
- default: WASM_UNREACHABLE();
- }
- }
- if (value.type == f64) {
- switch (curr->op) {
- case NegFloat64: return value.neg();
- case AbsFloat64: return value.abs();
- case CeilFloat64: return value.ceil();
- case FloorFloat64: return value.floor();
- case TruncFloat64: return value.trunc();
- case NearestFloat64: return value.nearbyint();
- case SqrtFloat64: return value.sqrt();
- case TruncSFloat64ToInt32:
- case TruncSFloat64ToInt64: return truncSFloat(curr, value);
- case TruncUFloat64ToInt32:
- case TruncUFloat64ToInt64: return truncUFloat(curr, value);
- case ReinterpretFloat64: return value.castToI64();
- case DemoteFloat64: {
- double val = value.getFloat();
- if (std::isnan(val)) return Literal(float(val));
- if (std::isinf(val)) return Literal(float(val));
- // when close to the limit, but still truncatable to a valid value, do that
- // see https://github.com/WebAssembly/sexpr-wasm-prototype/blob/2d375e8d502327e814d62a08f22da9d9b6b675dc/src/wasm-interpreter.c#L247
- uint64_t bits = value.reinterpreti64();
- if (bits > 0x47efffffe0000000ULL && bits < 0x47effffff0000000ULL) return Literal(std::numeric_limits<float>::max());
- if (bits > 0xc7efffffe0000000ULL && bits < 0xc7effffff0000000ULL) return Literal(-std::numeric_limits<float>::max());
- // when we must convert to infinity, do that
- if (val < -std::numeric_limits<float>::max()) return Literal(-std::numeric_limits<float>::infinity());
- if (val > std::numeric_limits<float>::max()) return Literal(std::numeric_limits<float>::infinity());
- return value.truncateToF32();
- }
- default: WASM_UNREACHABLE();
- }
- }
- WASM_UNREACHABLE();
}
Flow visitBinary(Binary *curr) {
NOTE_ENTER("Binary");
@@ -335,111 +309,115 @@ public:
NOTE_EVAL2(left, right);
assert(isConcreteType(curr->left->type) ? left.type == curr->left->type : true);
assert(isConcreteType(curr->right->type) ? right.type == curr->right->type : true);
- if (left.type == i32) {
- switch (curr->op) {
- case AddInt32: return left.add(right);
- case SubInt32: return left.sub(right);
- case MulInt32: return left.mul(right);
- case DivSInt32: {
- if (right.getInteger() == 0) trap("i32.div_s by 0");
- if (left.getInteger() == std::numeric_limits<int32_t>::min() && right.getInteger() == -1) trap("i32.div_s overflow"); // signed division overflow
- return left.divS(right);
- }
- case DivUInt32: {
- if (right.getInteger() == 0) trap("i32.div_u by 0");
- return left.divU(right);
- }
- case RemSInt32: {
- if (right.getInteger() == 0) trap("i32.rem_s by 0");
- if (left.getInteger() == std::numeric_limits<int32_t>::min() && right.getInteger() == -1) return Literal(int32_t(0));
- return left.remS(right);
- }
- case RemUInt32: {
- if (right.getInteger() == 0) trap("i32.rem_u by 0");
- return left.remU(right);
- }
- case AndInt32: return left.and_(right);
- case OrInt32: return left.or_(right);
- case XorInt32: return left.xor_(right);
- case ShlInt32: return left.shl(right.and_(Literal(int32_t(31))));
- case ShrUInt32: return left.shrU(right.and_(Literal(int32_t(31))));
- case ShrSInt32: return left.shrS(right.and_(Literal(int32_t(31))));
- case RotLInt32: return left.rotL(right);
- case RotRInt32: return left.rotR(right);
- case EqInt32: return left.eq(right);
- case NeInt32: return left.ne(right);
- case LtSInt32: return left.ltS(right);
- case LtUInt32: return left.ltU(right);
- case LeSInt32: return left.leS(right);
- case LeUInt32: return left.leU(right);
- case GtSInt32: return left.gtS(right);
- case GtUInt32: return left.gtU(right);
- case GeSInt32: return left.geS(right);
- case GeUInt32: return left.geU(right);
- default: WASM_UNREACHABLE();
+ switch (curr->op) {
+ case AddInt32:
+ case AddInt64:
+ case AddFloat32:
+ case AddFloat64: return left.add(right);
+ case SubInt32:
+ case SubInt64:
+ case SubFloat32:
+ case SubFloat64: return left.sub(right);
+ case MulInt32:
+ case MulInt64:
+ case MulFloat32:
+ case MulFloat64: return left.mul(right);
+ case DivSInt32: {
+ if (right.getInteger() == 0) trap("i32.div_s by 0");
+ if (left.getInteger() == std::numeric_limits<int32_t>::min() && right.getInteger() == -1) trap("i32.div_s overflow"); // signed division overflow
+ return left.divS(right);
}
- } else if (left.type == i64) {
- switch (curr->op) {
- case AddInt64: return left.add(right);
- case SubInt64: return left.sub(right);
- case MulInt64: return left.mul(right);
- case DivSInt64: {
- if (right.getInteger() == 0) trap("i64.div_s by 0");
- if (left.getInteger() == LLONG_MIN && right.getInteger() == -1LL) trap("i64.div_s overflow"); // signed division overflow
- return left.divS(right);
- }
- case DivUInt64: {
- if (right.getInteger() == 0) trap("i64.div_u by 0");
- return left.divU(right);
- }
- case RemSInt64: {
- if (right.getInteger() == 0) trap("i64.rem_s by 0");
- if (left.getInteger() == LLONG_MIN && right.getInteger() == -1LL) return Literal(int64_t(0));
- return left.remS(right);
- }
- case RemUInt64: {
- if (right.getInteger() == 0) trap("i64.rem_u by 0");
- return left.remU(right);
- }
- case AndInt64: return left.and_(right);
- case OrInt64: return left.or_(right);
- case XorInt64: return left.xor_(right);
- case ShlInt64: return left.shl(right.and_(Literal(int64_t(63))));
- case ShrUInt64: return left.shrU(right.and_(Literal(int64_t(63))));
- case ShrSInt64: return left.shrS(right.and_(Literal(int64_t(63))));
- case RotLInt64: return left.rotL(right);
- case RotRInt64: return left.rotR(right);
- case EqInt64: return left.eq(right);
- case NeInt64: return left.ne(right);
- case LtSInt64: return left.ltS(right);
- case LtUInt64: return left.ltU(right);
- case LeSInt64: return left.leS(right);
- case LeUInt64: return left.leU(right);
- case GtSInt64: return left.gtS(right);
- case GtUInt64: return left.gtU(right);
- case GeSInt64: return left.geS(right);
- case GeUInt64: return left.geU(right);
- default: WASM_UNREACHABLE();
+ case DivUInt32: {
+ if (right.getInteger() == 0) trap("i32.div_u by 0");
+ return left.divU(right);
}
- } else if (left.type == f32 || left.type == f64) {
- switch (curr->op) {
- case AddFloat32: case AddFloat64: return left.add(right);
- case SubFloat32: case SubFloat64: return left.sub(right);
- case MulFloat32: case MulFloat64: return left.mul(right);
- case DivFloat32: case DivFloat64: return left.div(right);
- case CopySignFloat32: case CopySignFloat64: return left.copysign(right);
- case MinFloat32: case MinFloat64: return left.min(right);
- case MaxFloat32: case MaxFloat64: return left.max(right);
- case EqFloat32: case EqFloat64: return left.eq(right);
- case NeFloat32: case NeFloat64: return left.ne(right);
- case LtFloat32: case LtFloat64: return left.lt(right);
- case LeFloat32: case LeFloat64: return left.le(right);
- case GtFloat32: case GtFloat64: return left.gt(right);
- case GeFloat32: case GeFloat64: return left.ge(right);
- default: WASM_UNREACHABLE();
+ case RemSInt32: {
+ if (right.getInteger() == 0) trap("i32.rem_s by 0");
+ if (left.getInteger() == std::numeric_limits<int32_t>::min() && right.getInteger() == -1) return Literal(int32_t(0));
+ return left.remS(right);
+ }
+ case RemUInt32: {
+ if (right.getInteger() == 0) trap("i32.rem_u by 0");
+ return left.remU(right);
+ }
+ case DivSInt64: {
+ if (right.getInteger() == 0) trap("i64.div_s by 0");
+ if (left.getInteger() == LLONG_MIN && right.getInteger() == -1LL) trap("i64.div_s overflow"); // signed division overflow
+ return left.divS(right);
}
+ case DivUInt64: {
+ if (right.getInteger() == 0) trap("i64.div_u by 0");
+ return left.divU(right);
+ }
+ case RemSInt64: {
+ if (right.getInteger() == 0) trap("i64.rem_s by 0");
+ if (left.getInteger() == LLONG_MIN && right.getInteger() == -1LL) return Literal(int64_t(0));
+ return left.remS(right);
+ }
+ case RemUInt64: {
+ if (right.getInteger() == 0) trap("i64.rem_u by 0");
+ return left.remU(right);
+ }
+ case DivFloat32:
+ case DivFloat64: return left.div(right);
+ case AndInt32:
+ case AndInt64: return left.and_(right);
+ case OrInt32:
+ case OrInt64: return left.or_(right);
+ case XorInt32:
+ case XorInt64: return left.xor_(right);
+ case ShlInt32:
+ case ShlInt64: return left.shl(right);
+ case ShrUInt32:
+ case ShrUInt64: return left.shrU(right);
+ case ShrSInt32:
+ case ShrSInt64: return left.shrS(right);
+ case RotLInt32:
+ case RotLInt64: return left.rotL(right);
+ case RotRInt32:
+ case RotRInt64: return left.rotR(right);
+
+ case EqInt32:
+ case EqInt64:
+ case EqFloat32:
+ case EqFloat64: return left.eq(right);
+ case NeInt32:
+ case NeInt64:
+ case NeFloat32:
+ case NeFloat64: return left.ne(right);
+ case LtSInt32:
+ case LtSInt64: return left.ltS(right);
+ case LtUInt32:
+ case LtUInt64: return left.ltU(right);
+ case LeSInt32:
+ case LeSInt64: return left.leS(right);
+ case LeUInt32:
+ case LeUInt64: return left.leU(right);
+ case GtSInt32:
+ case GtSInt64: return left.gtS(right);
+ case GtUInt32:
+ case GtUInt64: return left.gtU(right);
+ case GeSInt32:
+ case GeSInt64: return left.geS(right);
+ case GeUInt32:
+ case GeUInt64: return left.geU(right);
+ case LtFloat32:
+ case LtFloat64: return left.lt(right);
+ case LeFloat32:
+ case LeFloat64: return left.le(right);
+ case GtFloat32:
+ case GtFloat64: return left.gt(right);
+ case GeFloat32:
+ case GeFloat64: return left.ge(right);
+
+ case CopySignFloat32:
+ case CopySignFloat64: return left.copysign(right);
+ case MinFloat32:
+ case MinFloat64: return left.min(right);
+ case MaxFloat32:
+ case MaxFloat64: return left.max(right);
+ default: WASM_UNREACHABLE();
}
- WASM_UNREACHABLE();
}
Flow visitSelect(Select *curr) {
NOTE_ENTER("Select");
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp
index a6dcd17f0..e68dd5d5f 100644
--- a/src/wasm/literal.cpp
+++ b/src/wasm/literal.cpp
@@ -22,6 +22,8 @@
#include "emscripten-optimizer/simple_ast.h"
#include "pretty_printing.h"
#include "support/bits.h"
+#include "ir/bits.h"
+
namespace wasm {
@@ -213,6 +215,23 @@ Literal Literal::extendToF64() const {
return Literal(double(getf32()));
}
+Literal Literal::extendS8() const {
+ if (type == Type::i32) return Literal(int32_t(int8_t(geti32() & 0xFF)));
+ if (type == Type::i64) return Literal(int64_t(int8_t(geti64() & 0xFF)));
+ WASM_UNREACHABLE();
+}
+
+Literal Literal::extendS16() const {
+ if (type == Type::i32) return Literal(int32_t(int16_t(geti32() & 0xFFFF)));
+ if (type == Type::i64) return Literal(int64_t(int16_t(geti64() & 0xFFFF)));
+ WASM_UNREACHABLE();
+}
+
+Literal Literal::extendS32() const {
+ if (type == Type::i64) return Literal(int64_t(int32_t(geti64() & 0xFFFFFFFF)));
+ WASM_UNREACHABLE();
+}
+
Literal Literal::truncateToI32() const {
assert(type == Type::i64);
return Literal((int32_t)i64);
@@ -247,6 +266,16 @@ Literal Literal::convertUToF64() const {
WASM_UNREACHABLE();
}
+Literal Literal::eqz() const {
+ switch (type) {
+ case Type::i32: return eq(Literal(int32_t(0)));
+ case Type::i64: return eq(Literal(int64_t(0)));
+ case Type::f32: return eq(Literal(float(0)));
+ case Type::f64: return eq(Literal(double(0)));
+ default: WASM_UNREACHABLE();
+ }
+}
+
Literal Literal::neg() const {
switch (type) {
case Type::i32: return Literal(-uint32_t(i32));
@@ -307,6 +336,21 @@ Literal Literal::sqrt() const {
}
}
+Literal Literal::demote() const {
+ auto f64 = getf64();
+ if (std::isnan(f64)) return Literal(float(f64));
+ if (std::isinf(f64)) return Literal(float(f64));
+ // when close to the limit, but still truncatable to a valid value, do that
+ // see https://github.com/WebAssembly/sexpr-wasm-prototype/blob/2d375e8d502327e814d62a08f22da9d9b6b675dc/src/wasm-interpreter.c#L247
+ uint64_t bits = reinterpreti64();
+ if (bits > 0x47efffffe0000000ULL && bits < 0x47effffff0000000ULL) return Literal(std::numeric_limits<float>::max());
+ if (bits > 0xc7efffffe0000000ULL && bits < 0xc7effffff0000000ULL) return Literal(-std::numeric_limits<float>::max());
+ // when we must convert to infinity, do that
+ if (f64 < -std::numeric_limits<float>::max()) return Literal(-std::numeric_limits<float>::infinity());
+ if (f64 > std::numeric_limits<float>::max()) return Literal(std::numeric_limits<float>::infinity());
+ return truncateToF32();
+}
+
Literal Literal::add(const Literal& other) const {
switch (type) {
case Type::i32: return Literal(uint32_t(i32) + uint32_t(other.i32));
@@ -441,24 +485,24 @@ Literal Literal::xor_(const Literal& other) const {
Literal Literal::shl(const Literal& other) const {
switch (type) {
- case Type::i32: return Literal(uint32_t(i32) << shiftMask(other.i32));
- case Type::i64: return Literal(uint64_t(i64) << shiftMask(other.i64));
+ case Type::i32: return Literal(uint32_t(i32) << Bits::getEffectiveShifts(other.i32, Type::i32));
+ case Type::i64: return Literal(uint64_t(i64) << Bits::getEffectiveShifts(other.i64, Type::i64));
default: WASM_UNREACHABLE();
}
}
Literal Literal::shrS(const Literal& other) const {
switch (type) {
- case Type::i32: return Literal(i32 >> shiftMask(other.i32));
- case Type::i64: return Literal(i64 >> shiftMask(other.i64));
+ case Type::i32: return Literal(i32 >> Bits::getEffectiveShifts(other.i32, Type::i32));
+ case Type::i64: return Literal(i64 >> Bits::getEffectiveShifts(other.i64, Type::i64));
default: WASM_UNREACHABLE();
}
}
Literal Literal::shrU(const Literal& other) const {
switch (type) {
- case Type::i32: return Literal(uint32_t(i32) >> shiftMask(other.i32));
- case Type::i64: return Literal(uint64_t(i64) >> shiftMask(other.i64));
+ case Type::i32: return Literal(uint32_t(i32) >> Bits::getEffectiveShifts(other.i32, Type::i32));
+ case Type::i64: return Literal(uint64_t(i64) >> Bits::getEffectiveShifts(other.i64, Type::i64));
default: WASM_UNREACHABLE();
}
}