summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/metal_backend.rs4
-rw-r--r--candle-metal-kernels/src/lib.rs5
-rw-r--r--candle-metal-kernels/src/unary.metal1
3 files changed, 9 insertions, 1 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 7a22595e..e168c24b 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -665,6 +665,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
+ ("uabs", DType::F32) => contiguous::abs::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
@@ -680,6 +681,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
+ ("uabs", DType::F16) => contiguous::abs::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
@@ -712,6 +714,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
+ ("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
@@ -725,6 +728,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
+ ("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("uround", DType::F16) => strided::round::HALF,
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index d080ef52..5d34f61a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -173,7 +173,10 @@ macro_rules! ops{
}
pub mod unary {
- ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, recip);
+ ops!(
+ cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
+ recip
+ );
}
pub mod binary {
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 15d1e400..7fbb613d 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -102,6 +102,7 @@ UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
+UNARY_OP(abs)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)