diff options
-rw-r--r-- | src/literal.h | 11 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 352 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 56 |
3 files changed, 220 insertions, 199 deletions
diff --git a/src/literal.h b/src/literal.h index 72bd36019..ae628149e 100644 --- a/src/literal.h +++ b/src/literal.h @@ -38,12 +38,6 @@ private: int64_t i64; }; - // The RHS of shl/shru/shrs must be masked by bitwidth. - template <typename T> - static T shiftMask(T val) { - return val & (sizeof(T) * 8 - 1); - } - public: Literal() : type(Type::none), i64(0) {} explicit Literal(Type type) : type(type), i64(0) {} @@ -98,6 +92,9 @@ public: Literal extendToSI64() const; Literal extendToUI64() const; Literal extendToF64() const; + Literal extendS8() const; + Literal extendS16() const; + Literal extendS32() const; Literal truncateToI32() const; Literal truncateToF32() const; @@ -106,6 +103,7 @@ public: Literal convertSToF64() const; Literal convertUToF64() const; + Literal eqz() const; Literal neg() const; Literal abs() const; Literal ceil() const; @@ -113,6 +111,7 @@ public: Literal trunc() const; Literal nearbyint() const; Literal sqrt() const; + Literal demote() const; Literal add(const Literal& other) const; Literal sub(const Literal& other) const; diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 73cfc393d..d1aab7ab6 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -32,7 +32,6 @@ #include "wasm.h" #include "wasm-traversal.h" - #ifdef WASM_INTERPRETER_DEBUG #include "wasm-printing.h" #endif @@ -231,98 +230,73 @@ public: NOTE_EVAL1(curr->value); return Flow(curr->value); // heh } + + // Unary and Binary nodes, the core math computations. We mostly just + // delegate to the Literal::* methods, except we handle traps here. + Flow visitUnary(Unary *curr) { NOTE_ENTER("Unary"); Flow flow = visit(curr->value); if (flow.breaking()) return flow; Literal value = flow.value; NOTE_EVAL1(value); - if (value.type == i32) { - switch (curr->op) { - case ClzInt32: return value.countLeadingZeroes(); - case CtzInt32: return value.countTrailingZeroes(); - case PopcntInt32: return value.popCount(); - case EqZInt32: return Literal(int32_t(value == Literal(int32_t(0)))); - case ReinterpretInt32: return value.castToF32(); - case ExtendSInt32: return value.extendToSI64(); - case ExtendUInt32: return value.extendToUI64(); - case ConvertUInt32ToFloat32: return value.convertUToF32(); - case ConvertUInt32ToFloat64: return value.convertUToF64(); - case ConvertSInt32ToFloat32: return value.convertSToF32(); - case ConvertSInt32ToFloat64: return value.convertSToF64(); - case ExtendS8Int32: return Literal(int32_t(int8_t(value.geti32() & 0xFF))); - case ExtendS16Int32: return Literal(int32_t(int16_t(value.geti32() & 0xFFFF))); - default: WASM_UNREACHABLE(); - } + switch (curr->op) { + case ClzInt32: + case ClzInt64: return value.countLeadingZeroes(); + case CtzInt32: + case CtzInt64: return value.countTrailingZeroes(); + case PopcntInt32: + case PopcntInt64: return value.popCount(); + case EqZInt32: + case EqZInt64: return value.eqz(); + case ReinterpretInt32: return value.castToF32(); + case ReinterpretInt64: return value.castToF64(); + case ExtendSInt32: return value.extendToSI64(); + case ExtendUInt32: return value.extendToUI64(); + case WrapInt64: return value.truncateToI32(); + case ConvertUInt32ToFloat32: + case ConvertUInt64ToFloat32: return value.convertUToF32(); + case ConvertUInt32ToFloat64: + case ConvertUInt64ToFloat64: return value.convertUToF64(); + case ConvertSInt32ToFloat32: + case ConvertSInt64ToFloat32: return value.convertSToF32(); + case ConvertSInt32ToFloat64: + case ConvertSInt64ToFloat64: return value.convertSToF64(); + case ExtendS8Int32: + case ExtendS8Int64: return value.extendS8(); + case ExtendS16Int32: + case ExtendS16Int64: return value.extendS16(); + case ExtendS32Int64: return value.extendS32(); + + case NegFloat32: + case NegFloat64: return value.neg(); + case AbsFloat32: + case AbsFloat64: return value.abs(); + case CeilFloat32: + case CeilFloat64: return value.ceil(); + case FloorFloat32: + case FloorFloat64: return value.floor(); + case TruncFloat32: + case TruncFloat64: return value.trunc(); + case NearestFloat32: + case NearestFloat64: return value.nearbyint(); + case SqrtFloat32: + case SqrtFloat64: return value.sqrt(); + case TruncSFloat32ToInt32: + case TruncSFloat64ToInt32: + case TruncSFloat32ToInt64: + case TruncSFloat64ToInt64: return truncSFloat(curr, value); + case TruncUFloat32ToInt32: + case TruncUFloat64ToInt32: + case TruncUFloat32ToInt64: + case TruncUFloat64ToInt64: return truncUFloat(curr, value); + case ReinterpretFloat32: return value.castToI32(); + case PromoteFloat32: return value.extendToF64(); + case ReinterpretFloat64: return value.castToI64(); + case DemoteFloat64: return value.demote(); + + default: WASM_UNREACHABLE(); } - if (value.type == i64) { - switch (curr->op) { - case ClzInt64: return value.countLeadingZeroes(); - case CtzInt64: return value.countTrailingZeroes(); - case PopcntInt64: return value.popCount(); - case EqZInt64: return Literal(int32_t(value == Literal(int64_t(0)))); - case WrapInt64: return value.truncateToI32(); - case ReinterpretInt64: return value.castToF64(); - case ConvertUInt64ToFloat32: return value.convertUToF32(); - case ConvertUInt64ToFloat64: return value.convertUToF64(); - case ConvertSInt64ToFloat32: return value.convertSToF32(); - case ConvertSInt64ToFloat64: return value.convertSToF64(); - case ExtendS8Int64: return Literal(int64_t(int8_t(value.geti64() & 0xFF))); - case ExtendS16Int64: return Literal(int64_t(int16_t(value.geti64() & 0xFFFF))); - case ExtendS32Int64: return Literal(int64_t(int32_t(value.geti64() & 0xFFFFFFFF))); - default: WASM_UNREACHABLE(); - } - } - if (value.type == f32) { - switch (curr->op) { - case NegFloat32: return value.neg(); - case AbsFloat32: return value.abs(); - case CeilFloat32: return value.ceil(); - case FloorFloat32: return value.floor(); - case TruncFloat32: return value.trunc(); - case NearestFloat32: return value.nearbyint(); - case SqrtFloat32: return value.sqrt(); - case TruncSFloat32ToInt32: - case TruncSFloat32ToInt64: return truncSFloat(curr, value); - case TruncUFloat32ToInt32: - case TruncUFloat32ToInt64: return truncUFloat(curr, value); - case ReinterpretFloat32: return value.castToI32(); - case PromoteFloat32: return value.extendToF64(); - default: WASM_UNREACHABLE(); - } - } - if (value.type == f64) { - switch (curr->op) { - case NegFloat64: return value.neg(); - case AbsFloat64: return value.abs(); - case CeilFloat64: return value.ceil(); - case FloorFloat64: return value.floor(); - case TruncFloat64: return value.trunc(); - case NearestFloat64: return value.nearbyint(); - case SqrtFloat64: return value.sqrt(); - case TruncSFloat64ToInt32: - case TruncSFloat64ToInt64: return truncSFloat(curr, value); - case TruncUFloat64ToInt32: - case TruncUFloat64ToInt64: return truncUFloat(curr, value); - case ReinterpretFloat64: return value.castToI64(); - case DemoteFloat64: { - double val = value.getFloat(); - if (std::isnan(val)) return Literal(float(val)); - if (std::isinf(val)) return Literal(float(val)); - // when close to the limit, but still truncatable to a valid value, do that - // see https://github.com/WebAssembly/sexpr-wasm-prototype/blob/2d375e8d502327e814d62a08f22da9d9b6b675dc/src/wasm-interpreter.c#L247 - uint64_t bits = value.reinterpreti64(); - if (bits > 0x47efffffe0000000ULL && bits < 0x47effffff0000000ULL) return Literal(std::numeric_limits<float>::max()); - if (bits > 0xc7efffffe0000000ULL && bits < 0xc7effffff0000000ULL) return Literal(-std::numeric_limits<float>::max()); - // when we must convert to infinity, do that - if (val < -std::numeric_limits<float>::max()) return Literal(-std::numeric_limits<float>::infinity()); - if (val > std::numeric_limits<float>::max()) return Literal(std::numeric_limits<float>::infinity()); - return value.truncateToF32(); - } - default: WASM_UNREACHABLE(); - } - } - WASM_UNREACHABLE(); } Flow visitBinary(Binary *curr) { NOTE_ENTER("Binary"); @@ -335,111 +309,115 @@ public: NOTE_EVAL2(left, right); assert(isConcreteType(curr->left->type) ? left.type == curr->left->type : true); assert(isConcreteType(curr->right->type) ? right.type == curr->right->type : true); - if (left.type == i32) { - switch (curr->op) { - case AddInt32: return left.add(right); - case SubInt32: return left.sub(right); - case MulInt32: return left.mul(right); - case DivSInt32: { - if (right.getInteger() == 0) trap("i32.div_s by 0"); - if (left.getInteger() == std::numeric_limits<int32_t>::min() && right.getInteger() == -1) trap("i32.div_s overflow"); // signed division overflow - return left.divS(right); - } - case DivUInt32: { - if (right.getInteger() == 0) trap("i32.div_u by 0"); - return left.divU(right); - } - case RemSInt32: { - if (right.getInteger() == 0) trap("i32.rem_s by 0"); - if (left.getInteger() == std::numeric_limits<int32_t>::min() && right.getInteger() == -1) return Literal(int32_t(0)); - return left.remS(right); - } - case RemUInt32: { - if (right.getInteger() == 0) trap("i32.rem_u by 0"); - return left.remU(right); - } - case AndInt32: return left.and_(right); - case OrInt32: return left.or_(right); - case XorInt32: return left.xor_(right); - case ShlInt32: return left.shl(right.and_(Literal(int32_t(31)))); - case ShrUInt32: return left.shrU(right.and_(Literal(int32_t(31)))); - case ShrSInt32: return left.shrS(right.and_(Literal(int32_t(31)))); - case RotLInt32: return left.rotL(right); - case RotRInt32: return left.rotR(right); - case EqInt32: return left.eq(right); - case NeInt32: return left.ne(right); - case LtSInt32: return left.ltS(right); - case LtUInt32: return left.ltU(right); - case LeSInt32: return left.leS(right); - case LeUInt32: return left.leU(right); - case GtSInt32: return left.gtS(right); - case GtUInt32: return left.gtU(right); - case GeSInt32: return left.geS(right); - case GeUInt32: return left.geU(right); - default: WASM_UNREACHABLE(); + switch (curr->op) { + case AddInt32: + case AddInt64: + case AddFloat32: + case AddFloat64: return left.add(right); + case SubInt32: + case SubInt64: + case SubFloat32: + case SubFloat64: return left.sub(right); + case MulInt32: + case MulInt64: + case MulFloat32: + case MulFloat64: return left.mul(right); + case DivSInt32: { + if (right.getInteger() == 0) trap("i32.div_s by 0"); + if (left.getInteger() == std::numeric_limits<int32_t>::min() && right.getInteger() == -1) trap("i32.div_s overflow"); // signed division overflow + return left.divS(right); } - } else if (left.type == i64) { - switch (curr->op) { - case AddInt64: return left.add(right); - case SubInt64: return left.sub(right); - case MulInt64: return left.mul(right); - case DivSInt64: { - if (right.getInteger() == 0) trap("i64.div_s by 0"); - if (left.getInteger() == LLONG_MIN && right.getInteger() == -1LL) trap("i64.div_s overflow"); // signed division overflow - return left.divS(right); - } - case DivUInt64: { - if (right.getInteger() == 0) trap("i64.div_u by 0"); - return left.divU(right); - } - case RemSInt64: { - if (right.getInteger() == 0) trap("i64.rem_s by 0"); - if (left.getInteger() == LLONG_MIN && right.getInteger() == -1LL) return Literal(int64_t(0)); - return left.remS(right); - } - case RemUInt64: { - if (right.getInteger() == 0) trap("i64.rem_u by 0"); - return left.remU(right); - } - case AndInt64: return left.and_(right); - case OrInt64: return left.or_(right); - case XorInt64: return left.xor_(right); - case ShlInt64: return left.shl(right.and_(Literal(int64_t(63)))); - case ShrUInt64: return left.shrU(right.and_(Literal(int64_t(63)))); - case ShrSInt64: return left.shrS(right.and_(Literal(int64_t(63)))); - case RotLInt64: return left.rotL(right); - case RotRInt64: return left.rotR(right); - case EqInt64: return left.eq(right); - case NeInt64: return left.ne(right); - case LtSInt64: return left.ltS(right); - case LtUInt64: return left.ltU(right); - case LeSInt64: return left.leS(right); - case LeUInt64: return left.leU(right); - case GtSInt64: return left.gtS(right); - case GtUInt64: return left.gtU(right); - case GeSInt64: return left.geS(right); - case GeUInt64: return left.geU(right); - default: WASM_UNREACHABLE(); + case DivUInt32: { + if (right.getInteger() == 0) trap("i32.div_u by 0"); + return left.divU(right); } - } else if (left.type == f32 || left.type == f64) { - switch (curr->op) { - case AddFloat32: case AddFloat64: return left.add(right); - case SubFloat32: case SubFloat64: return left.sub(right); - case MulFloat32: case MulFloat64: return left.mul(right); - case DivFloat32: case DivFloat64: return left.div(right); - case CopySignFloat32: case CopySignFloat64: return left.copysign(right); - case MinFloat32: case MinFloat64: return left.min(right); - case MaxFloat32: case MaxFloat64: return left.max(right); - case EqFloat32: case EqFloat64: return left.eq(right); - case NeFloat32: case NeFloat64: return left.ne(right); - case LtFloat32: case LtFloat64: return left.lt(right); - case LeFloat32: case LeFloat64: return left.le(right); - case GtFloat32: case GtFloat64: return left.gt(right); - case GeFloat32: case GeFloat64: return left.ge(right); - default: WASM_UNREACHABLE(); + case RemSInt32: { + if (right.getInteger() == 0) trap("i32.rem_s by 0"); + if (left.getInteger() == std::numeric_limits<int32_t>::min() && right.getInteger() == -1) return Literal(int32_t(0)); + return left.remS(right); + } + case RemUInt32: { + if (right.getInteger() == 0) trap("i32.rem_u by 0"); + return left.remU(right); + } + case DivSInt64: { + if (right.getInteger() == 0) trap("i64.div_s by 0"); + if (left.getInteger() == LLONG_MIN && right.getInteger() == -1LL) trap("i64.div_s overflow"); // signed division overflow + return left.divS(right); } + case DivUInt64: { + if (right.getInteger() == 0) trap("i64.div_u by 0"); + return left.divU(right); + } + case RemSInt64: { + if (right.getInteger() == 0) trap("i64.rem_s by 0"); + if (left.getInteger() == LLONG_MIN && right.getInteger() == -1LL) return Literal(int64_t(0)); + return left.remS(right); + } + case RemUInt64: { + if (right.getInteger() == 0) trap("i64.rem_u by 0"); + return left.remU(right); + } + case DivFloat32: + case DivFloat64: return left.div(right); + case AndInt32: + case AndInt64: return left.and_(right); + case OrInt32: + case OrInt64: return left.or_(right); + case XorInt32: + case XorInt64: return left.xor_(right); + case ShlInt32: + case ShlInt64: return left.shl(right); + case ShrUInt32: + case ShrUInt64: return left.shrU(right); + case ShrSInt32: + case ShrSInt64: return left.shrS(right); + case RotLInt32: + case RotLInt64: return left.rotL(right); + case RotRInt32: + case RotRInt64: return left.rotR(right); + + case EqInt32: + case EqInt64: + case EqFloat32: + case EqFloat64: return left.eq(right); + case NeInt32: + case NeInt64: + case NeFloat32: + case NeFloat64: return left.ne(right); + case LtSInt32: + case LtSInt64: return left.ltS(right); + case LtUInt32: + case LtUInt64: return left.ltU(right); + case LeSInt32: + case LeSInt64: return left.leS(right); + case LeUInt32: + case LeUInt64: return left.leU(right); + case GtSInt32: + case GtSInt64: return left.gtS(right); + case GtUInt32: + case GtUInt64: return left.gtU(right); + case GeSInt32: + case GeSInt64: return left.geS(right); + case GeUInt32: + case GeUInt64: return left.geU(right); + case LtFloat32: + case LtFloat64: return left.lt(right); + case LeFloat32: + case LeFloat64: return left.le(right); + case GtFloat32: + case GtFloat64: return left.gt(right); + case GeFloat32: + case GeFloat64: return left.ge(right); + + case CopySignFloat32: + case CopySignFloat64: return left.copysign(right); + case MinFloat32: + case MinFloat64: return left.min(right); + case MaxFloat32: + case MaxFloat64: return left.max(right); + default: WASM_UNREACHABLE(); } - WASM_UNREACHABLE(); } Flow visitSelect(Select *curr) { NOTE_ENTER("Select"); diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index a6dcd17f0..e68dd5d5f 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -22,6 +22,8 @@ #include "emscripten-optimizer/simple_ast.h" #include "pretty_printing.h" #include "support/bits.h" +#include "ir/bits.h" + namespace wasm { @@ -213,6 +215,23 @@ Literal Literal::extendToF64() const { return Literal(double(getf32())); } +Literal Literal::extendS8() const { + if (type == Type::i32) return Literal(int32_t(int8_t(geti32() & 0xFF))); + if (type == Type::i64) return Literal(int64_t(int8_t(geti64() & 0xFF))); + WASM_UNREACHABLE(); +} + +Literal Literal::extendS16() const { + if (type == Type::i32) return Literal(int32_t(int16_t(geti32() & 0xFFFF))); + if (type == Type::i64) return Literal(int64_t(int16_t(geti64() & 0xFFFF))); + WASM_UNREACHABLE(); +} + +Literal Literal::extendS32() const { + if (type == Type::i64) return Literal(int64_t(int32_t(geti64() & 0xFFFFFFFF))); + WASM_UNREACHABLE(); +} + Literal Literal::truncateToI32() const { assert(type == Type::i64); return Literal((int32_t)i64); @@ -247,6 +266,16 @@ Literal Literal::convertUToF64() const { WASM_UNREACHABLE(); } +Literal Literal::eqz() const { + switch (type) { + case Type::i32: return eq(Literal(int32_t(0))); + case Type::i64: return eq(Literal(int64_t(0))); + case Type::f32: return eq(Literal(float(0))); + case Type::f64: return eq(Literal(double(0))); + default: WASM_UNREACHABLE(); + } +} + Literal Literal::neg() const { switch (type) { case Type::i32: return Literal(-uint32_t(i32)); @@ -307,6 +336,21 @@ Literal Literal::sqrt() const { } } +Literal Literal::demote() const { + auto f64 = getf64(); + if (std::isnan(f64)) return Literal(float(f64)); + if (std::isinf(f64)) return Literal(float(f64)); + // when close to the limit, but still truncatable to a valid value, do that + // see https://github.com/WebAssembly/sexpr-wasm-prototype/blob/2d375e8d502327e814d62a08f22da9d9b6b675dc/src/wasm-interpreter.c#L247 + uint64_t bits = reinterpreti64(); + if (bits > 0x47efffffe0000000ULL && bits < 0x47effffff0000000ULL) return Literal(std::numeric_limits<float>::max()); + if (bits > 0xc7efffffe0000000ULL && bits < 0xc7effffff0000000ULL) return Literal(-std::numeric_limits<float>::max()); + // when we must convert to infinity, do that + if (f64 < -std::numeric_limits<float>::max()) return Literal(-std::numeric_limits<float>::infinity()); + if (f64 > std::numeric_limits<float>::max()) return Literal(std::numeric_limits<float>::infinity()); + return truncateToF32(); +} + Literal Literal::add(const Literal& other) const { switch (type) { case Type::i32: return Literal(uint32_t(i32) + uint32_t(other.i32)); @@ -441,24 +485,24 @@ Literal Literal::xor_(const Literal& other) const { Literal Literal::shl(const Literal& other) const { switch (type) { - case Type::i32: return Literal(uint32_t(i32) << shiftMask(other.i32)); - case Type::i64: return Literal(uint64_t(i64) << shiftMask(other.i64)); + case Type::i32: return Literal(uint32_t(i32) << Bits::getEffectiveShifts(other.i32, Type::i32)); + case Type::i64: return Literal(uint64_t(i64) << Bits::getEffectiveShifts(other.i64, Type::i64)); default: WASM_UNREACHABLE(); } } Literal Literal::shrS(const Literal& other) const { switch (type) { - case Type::i32: return Literal(i32 >> shiftMask(other.i32)); - case Type::i64: return Literal(i64 >> shiftMask(other.i64)); + case Type::i32: return Literal(i32 >> Bits::getEffectiveShifts(other.i32, Type::i32)); + case Type::i64: return Literal(i64 >> Bits::getEffectiveShifts(other.i64, Type::i64)); default: WASM_UNREACHABLE(); } } Literal Literal::shrU(const Literal& other) const { switch (type) { - case Type::i32: return Literal(uint32_t(i32) >> shiftMask(other.i32)); - case Type::i64: return Literal(uint64_t(i64) >> shiftMask(other.i64)); + case Type::i32: return Literal(uint32_t(i32) >> Bits::getEffectiveShifts(other.i32, Type::i32)); + case Type::i64: return Literal(uint64_t(i64) >> Bits::getEffectiveShifts(other.i64, Type::i64)); default: WASM_UNREACHABLE(); } } |