summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/metal_backend/mod.rs20
-rw-r--r--candle-metal-kernels/src/unary.metal2
2 files changed, 21 insertions, 1 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 6f1f64ee..daa68c39 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -540,6 +540,7 @@ impl BackendStorage for MetalStorage {
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
+
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
@@ -557,6 +558,25 @@ impl BackendStorage for MetalStorage {
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
+
+ ("ucos", DType::BF16) => strided::cos::BFLOAT,
+ ("usin", DType::BF16) => strided::sin::BFLOAT,
+ ("usqr", DType::BF16) => strided::sqr::BFLOAT,
+ ("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
+ ("uneg", DType::BF16) => strided::neg::BFLOAT,
+ ("uexp", DType::BF16) => strided::exp::BFLOAT,
+ ("ulog", DType::BF16) => strided::log::BFLOAT,
+ ("ugelu", DType::BF16) => strided::gelu::BFLOAT,
+ ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
+ ("uerf", DType::BF16) => strided::erf::BFLOAT,
+ ("usilu", DType::BF16) => strided::silu::BFLOAT,
+ ("uabs", DType::BF16) => strided::abs::BFLOAT,
+ ("uceil", DType::BF16) => strided::ceil::BFLOAT,
+ ("ufloor", DType::BF16) => strided::floor::BFLOAT,
+ ("urelu", DType::BF16) => strided::relu::BFLOAT,
+ ("uround", DType::BF16) => strided::round::BFLOAT,
+ ("utanh", DType::BF16) => strided::tanh::BFLOAT,
+
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 4b6363ed..ec793eae 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -175,5 +175,5 @@ BFLOAT_UNARY_OP(sign)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
-COPY2D(copy2d_bf64, bfloat)
+COPY2D(copy2d_bf16, bfloat)
#endif