summary | refs | log | tree | commit | diff
path: root/src/wasm
diff options
context:
space:
mode:
author Brendan Dahl <brendan.dahl@gmail.com> 2024-08-27 11:14:54 -0700
committer GitHub <noreply@github.com> 2024-08-27 11:14:54 -0700
commit 6c2d0e20906248ab8f8365702b35fd67db29c44f (patch)
tree 17bdd816e894b7872b90e4a81262e3adb65de1af /src/wasm
parent 459bc0797f67cb2a8fd4598bb7143b34036608d9 (diff)
download binaryen-6c2d0e20906248ab8f8365702b35fd67db29c44f.tar.gz
binaryen-6c2d0e20906248ab8f8365702b35fd67db29c44f.tar.bz2
binaryen-6c2d0e20906248ab8f8365702b35fd67db29c44f.zip
[FP16] Implement unary operations. (#6867)
Specified at https://github.com/WebAssembly/half-precision/blob/main/proposals/half-precision/Overview.md
Diffstat (limited to 'src/wasm')
-rw-r--r-- src/wasm/literal.cpp 36
-rw-r--r-- src/wasm/wasm-binary.cpp 28
-rw-r--r-- src/wasm/wasm-stack.cpp 22
-rw-r--r-- src/wasm/wasm-validator.cpp 11
-rw-r--r-- src/wasm/wasm.cpp 7
5 files changed, 97 insertions, 7 deletions
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp
index 6a4614a90..1b84ba53c 100644
--- a/src/wasm/literal.cpp
+++ b/src/wasm/literal.cpp
@@ -1842,13 +1842,19 @@ Literal Literal::replaceLaneF64x2(const Literal& other, uint8_t index) const {
return replace<2, &Literal::getLanesF64x2>(*this, other, index);
}
+static Literal passThrough(const Literal& literal) { return literal; }
+static Literal toFP16(const Literal& literal) {
+ return literal.convertF32ToF16();
+}
+
template<int Lanes,
LaneArray<Lanes> (Literal::*IntoLanes)() const,
- Literal (Literal::*UnaryOp)(void) const>
+ Literal (Literal::*UnaryOp)(void) const,
+ Literal (*Convert)(const Literal&) = passThrough>
static Literal unary(const Literal& val) {
LaneArray<Lanes> lanes = (val.*IntoLanes)();
for (size_t i = 0; i < Lanes; ++i) {
- lanes[i] = (lanes[i].*UnaryOp)();
+ lanes[i] = Convert((lanes[i].*UnaryOp)());
}
return Literal(lanes);
}
@@ -1885,6 +1891,27 @@ Literal Literal::negI32x4() const {
Literal Literal::negI64x2() const {
return unary<2, &Literal::getLanesI64x2, &Literal::neg>(*this);
}
+Literal Literal::absF16x8() const {
+ return unary<8, &Literal::getLanesF16x8, &Literal::abs, &toFP16>(*this);
+}
+Literal Literal::negF16x8() const {
+ return unary<8, &Literal::getLanesF16x8, &Literal::neg, &toFP16>(*this);
+}
+Literal Literal::sqrtF16x8() const {
+ return unary<8, &Literal::getLanesF16x8, &Literal::sqrt, &toFP16>(*this);
+}
+Literal Literal::ceilF16x8() const {
+ return unary<8, &Literal::getLanesF16x8, &Literal::ceil, &toFP16>(*this);
+}
+Literal Literal::floorF16x8() const {
+ return unary<8, &Literal::getLanesF16x8, &Literal::floor, &toFP16>(*this);
+}
+Literal Literal::truncF16x8() const {
+ return unary<8, &Literal::getLanesF16x8, &Literal::trunc, &toFP16>(*this);
+}
+Literal Literal::nearestF16x8() const {
+ return unary<8, &Literal::getLanesF16x8, &Literal::nearbyint, &toFP16>(*this);
+}
Literal Literal::absF32x4() const {
return unary<4, &Literal::getLanesF32x4, &Literal::abs>(*this);
}
@@ -2271,11 +2298,6 @@ Literal Literal::geF64x2(const Literal& other) const {
other);
}
-static Literal passThrough(const Literal& literal) { return literal; }
-static Literal toFP16(const Literal& literal) {
- return literal.convertF32ToF16();
-}
-
template<int Lanes,
LaneArray<Lanes> (Literal::*IntoLanes)() const,
Literal (Literal::*BinaryOp)(const Literal&) const,
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 06d17fdfb..16a926182 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -6439,6 +6439,34 @@ bool WasmBinaryReader::maybeVisitSIMDUnary(Expression*& out, uint32_t code) {
curr = allocator.alloc<Unary>();
curr->op = BitmaskVecI64x2;
break;
+ case BinaryConsts::F16x8Abs:
+ curr = allocator.alloc<Unary>();
+ curr->op = AbsVecF16x8;
+ break;
+ case BinaryConsts::F16x8Neg:
+ curr = allocator.alloc<Unary>();
+ curr->op = NegVecF16x8;
+ break;
+ case BinaryConsts::F16x8Sqrt:
+ curr = allocator.alloc<Unary>();
+ curr->op = SqrtVecF16x8;
+ break;
+ case BinaryConsts::F16x8Ceil:
+ curr = allocator.alloc<Unary>();
+ curr->op = CeilVecF16x8;
+ break;
+ case BinaryConsts::F16x8Floor:
+ curr = allocator.alloc<Unary>();
+ curr->op = FloorVecF16x8;
+ break;
+ case BinaryConsts::F16x8Trunc:
+ curr = allocator.alloc<Unary>();
+ curr->op = TruncVecF16x8;
+ break;
+ case BinaryConsts::F16x8Nearest:
+ curr = allocator.alloc<Unary>();
+ curr->op = NearestVecF16x8;
+ break;
case BinaryConsts::F32x4Abs:
curr = allocator.alloc<Unary>();
curr->op = AbsVecF32x4;
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp
index b7bfea617..59593ddde 100644
--- a/src/wasm/wasm-stack.cpp
+++ b/src/wasm/wasm-stack.cpp
@@ -1132,6 +1132,28 @@ void BinaryInstWriter::visitUnary(Unary* curr) {
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::I64x2Bitmask);
break;
+ case AbsVecF16x8:
+ o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Abs);
+ break;
+ case NegVecF16x8:
+ o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Neg);
+ break;
+ case SqrtVecF16x8:
+ o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Sqrt);
+ break;
+ case CeilVecF16x8:
+ o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Ceil);
+ break;
+ case FloorVecF16x8:
+ o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Floor);
+ break;
+ case TruncVecF16x8:
+ o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F16x8Trunc);
+ break;
+ case NearestVecF16x8:
+ o << int8_t(BinaryConsts::SIMDPrefix)
+ << U32LEB(BinaryConsts::F16x8Nearest);
+ break;
case AbsVecF32x4:
o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::F32x4Abs);
break;
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index f77eeefe7..40726d7cd 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -2095,6 +2095,17 @@ void FunctionValidator::visitUnary(Unary* curr) {
shouldBeEqual(
curr->value->type, Type(Type::f64), curr, "expected f64 splat value");
break;
+ case AbsVecF16x8:
+ case NegVecF16x8:
+ case SqrtVecF16x8:
+ case CeilVecF16x8:
+ case FloorVecF16x8:
+ case TruncVecF16x8:
+ case NearestVecF16x8:
+ shouldBeTrue(getModule()->features.hasFP16(),
+ curr,
+ "FP16 operations require FP16 [--enable-fp16]");
+ [[fallthrough]];
case NotVec128:
case PopcntVecI8x16:
case AbsVecI8x16:
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp
index e768f0dc4..98146dfbc 100644
--- a/src/wasm/wasm.cpp
+++ b/src/wasm/wasm.cpp
@@ -652,6 +652,13 @@ void Unary::finalize() {
case NegVecI16x8:
case NegVecI32x4:
case NegVecI64x2:
+ case AbsVecF16x8:
+ case NegVecF16x8:
+ case SqrtVecF16x8:
+ case CeilVecF16x8:
+ case FloorVecF16x8:
+ case TruncVecF16x8:
+ case NearestVecF16x8:
case AbsVecF32x4:
case NegVecF32x4:
case SqrtVecF32x4: