diff options
-rw-r--r-- | candle-core/src/metal_backend.rs | 4 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 5 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 1 |
3 files changed, 9 insertions, 1 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 7a22595e..e168c24b 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -665,6 +665,7 @@ impl BackendStorage for MetalStorage { ("ugelu", DType::F32) => contiguous::gelu::FLOAT, ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uabs", DType::F32) => contiguous::abs::FLOAT, ("uceil", DType::F32) => contiguous::ceil::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT, @@ -680,6 +681,7 @@ impl BackendStorage for MetalStorage { ("ugelu", DType::F16) => contiguous::gelu::HALF, ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, ("uerf", DType::F16) => contiguous::erf::HALF, + ("uabs", DType::F16) => contiguous::abs::HALF, ("uceil", DType::F16) => contiguous::ceil::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF, ("uround", DType::F16) => contiguous::round::HALF, @@ -712,6 +714,7 @@ impl BackendStorage for MetalStorage { ("ugelu", DType::F32) => strided::gelu::FLOAT, ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, ("uerf", DType::F32) => strided::erf::FLOAT, + ("uabs", DType::F32) => strided::abs::FLOAT, ("uceil", DType::F32) => strided::ceil::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT, ("uround", DType::F32) => strided::round::FLOAT, @@ -725,6 +728,7 @@ impl BackendStorage for MetalStorage { ("ugelu", DType::F16) => strided::gelu::HALF, ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, ("uerf", DType::F16) => strided::erf::HALF, + ("uabs", DType::F16) => strided::abs::HALF, ("uceil", DType::F16) => strided::ceil::HALF, ("ufloor", DType::F16) => strided::floor::HALF, ("uround", DType::F16) => strided::round::HALF, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d080ef52..5d34f61a 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -173,7 +173,10 @@ macro_rules! ops{ } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, recip); + ops!( + cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh, + recip + ); } pub mod binary { ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 15d1e400..7fbb613d 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -102,6 +102,7 @@ UNARY_OP(neg) UNARY_OP(exp) UNARY_OP(log) UNARY_OP(gelu) +UNARY_OP(abs) UNARY_OP(ceil) UNARY_OP(floor) UNARY_OP(round) |