diff options
Diffstat (limited to 'src/wasm')
-rw-r--r-- | src/wasm/literal.cpp | 51 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 32 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 24 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 8 |
4 files changed, 110 insertions, 5 deletions
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index d60e2f8a9..65c2b4e62 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -856,6 +856,10 @@ Literal Literal::convertUIToF64() const { WASM_UNREACHABLE("invalid type"); } +Literal Literal::convertF32ToF16() const { + return Literal(fp16_ieee_from_fp32_value(getf32())); +} + template<typename F> struct AsInt { using type = void; }; template<> struct AsInt<float> { using type = int32_t; }; template<> struct AsInt<double> { using type = int64_t; }; @@ -1791,8 +1795,7 @@ Literal Literal::splatI16x8() const { return splat<Type::i32, 8>(*this); } Literal Literal::splatI32x4() const { return splat<Type::i32, 4>(*this); } Literal Literal::splatI64x2() const { return splat<Type::i64, 2>(*this); } Literal Literal::splatF16x8() const { - uint16_t f16 = fp16_ieee_from_fp32_value(getf32()); - return splat<Type::i32, 8>(Literal(f16)); + return splat<Type::i32, 8>(convertF32ToF16()); } Literal Literal::splatF32x4() const { return splat<Type::f32, 4>(*this); } Literal Literal::splatF64x2() const { return splat<Type::f64, 2>(*this); } @@ -1848,7 +1851,7 @@ Literal Literal::replaceLaneI64x2(const Literal& other, uint8_t index) const { } Literal Literal::replaceLaneF16x8(const Literal& other, uint8_t index) const { return replace<8, &Literal::getLanesF16x8>( - *this, Literal(fp16_ieee_from_fp32_value(other.getf32())), index); + *this, other.convertF32ToF16(), index); } Literal Literal::replaceLaneF32x4(const Literal& other, uint8_t index) const { return replace<4, &Literal::getLanesF32x4>(*this, other, index); @@ -2286,14 +2289,20 @@ Literal Literal::geF64x2(const Literal& other) const { other); } +static Literal passThrough(const Literal& literal) { return literal; } +static Literal toFP16(const Literal& literal) { + return literal.convertF32ToF16(); +} + template<int Lanes, LaneArray<Lanes> (Literal::*IntoLanes)() const, - Literal (Literal::*BinaryOp)(const Literal&) const> + Literal (Literal::*BinaryOp)(const Literal&) const, + Literal (*Convert)(const Literal&) = passThrough> static Literal binary(const Literal& val, const Literal& other) { LaneArray<Lanes> lanes = (val.*IntoLanes)(); LaneArray<Lanes> other_lanes = (other.*IntoLanes)(); for (size_t i = 0; i < Lanes; ++i) { - lanes[i] = (lanes[i].*BinaryOp)(other_lanes[i]); + lanes[i] = Convert((lanes[i].*BinaryOp)(other_lanes[i])); } return Literal(lanes); } @@ -2418,6 +2427,38 @@ Literal Literal::subI64x2(const Literal& other) const { Literal Literal::mulI64x2(const Literal& other) const { return binary<2, &Literal::getLanesI64x2, &Literal::mul>(*this, other); } +Literal Literal::addF16x8(const Literal& other) const { + return binary<8, &Literal::getLanesF16x8, &Literal::add, &toFP16>(*this, + other); +} +Literal Literal::subF16x8(const Literal& other) const { + return binary<8, &Literal::getLanesF16x8, &Literal::sub, &toFP16>(*this, + other); +} +Literal Literal::mulF16x8(const Literal& other) const { + return binary<8, &Literal::getLanesF16x8, &Literal::mul, &toFP16>(*this, + other); +} +Literal Literal::divF16x8(const Literal& other) const { + return binary<8, &Literal::getLanesF16x8, &Literal::div, &toFP16>(*this, + other); +} +Literal Literal::minF16x8(const Literal& other) const { + return binary<8, &Literal::getLanesF16x8, &Literal::min, &toFP16>(*this, + other); +} +Literal Literal::maxF16x8(const Literal& other) const { + return binary<8, &Literal::getLanesF16x8, &Literal::max, &toFP16>(*this, + other); +} +Literal Literal::pminF16x8(const Literal& other) const { + return binary<8, &Literal::getLanesF16x8, &Literal::pmin, &toFP16>(*this, + other); +} +Literal Literal::pmaxF16x8(const Literal& other) const { + return binary<8, &Literal::getLanesF16x8, &Literal::pmax, &toFP16>(*this, + other); +} Literal Literal::addF32x4(const Literal& other) const { return binary<4, &Literal::getLanesF32x4, &Literal::add>(*this, other); } diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 865ca39ca..e84639801 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -6174,6 +6174,38 @@ bool WasmBinaryReader::maybeVisitSIMDBinary(Expression*& out, uint32_t code) { curr = allocator.alloc<Binary>(); curr->op = ExtMulHighUVecI64x2; break; + case BinaryConsts::F16x8Add: + curr = allocator.alloc<Binary>(); + curr->op = AddVecF16x8; + break; + case BinaryConsts::F16x8Sub: + curr = allocator.alloc<Binary>(); + curr->op = SubVecF16x8; + break; + case BinaryConsts::F16x8Mul: + curr = allocator.alloc<Binary>(); + curr->op = MulVecF16x8; + break; + case BinaryConsts::F16x8Div: + curr = allocator.alloc<Binary>(); + curr->op = DivVecF16x8; + break; + case BinaryConsts::F16x8Min: + curr = allocator.alloc<Binary>(); + curr->op = MinVecF16x8; + break; + case BinaryConsts::F16x8Max: + curr = allocator.alloc<Binary>(); + curr->op = MaxVecF16x8; + break; + case BinaryConsts::F16x8Pmin: + curr = allocator.alloc<Binary>(); + curr->op = PMinVecF16x8; + break; + case BinaryConsts::F16x8Pmax: + curr = allocator.alloc<Binary>(); + curr->op = PMaxVecF16x8; + break; case BinaryConsts::F32x4Add: curr = allocator.alloc<Binary>(); curr->op = AddVecF32x4; diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 1c2c2c42b..b7bfea617 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -1873,6 +1873,30 @@ void BinaryInstWriter::visitBinary(Binary* curr) { << U32LEB(BinaryConsts::I64x2ExtmulHighI32x4U); break; + case AddVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Add); + break; + case SubVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Sub); + break; + case MulVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Mul); + break; + case DivVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Div); + break; + case MinVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Min); + break; + case MaxVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Max); + break; + case PMinVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Pmin); + break; + case PMaxVecF16x8: + o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Pmax); + break; case AddVecF32x4: o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F32x4Add); break; diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 0bdc18658..24f5379fb 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -1813,6 +1813,14 @@ void FunctionValidator::visitBinary(Binary* curr) { case ExtMulHighSVecI64x2: case ExtMulLowUVecI64x2: case ExtMulHighUVecI64x2: + case AddVecF16x8: + case SubVecF16x8: + case MulVecF16x8: + case DivVecF16x8: + case MinVecF16x8: + case MaxVecF16x8: + case PMinVecF16x8: + case PMaxVecF16x8: case AddVecF32x4: case SubVecF32x4: case MulVecF32x4: |