diff options
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 20 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 2 |
2 files changed, 21 insertions, 1 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 6f1f64ee..daa68c39 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -540,6 +540,7 @@ impl BackendStorage for MetalStorage { ("urelu", DType::F32) => strided::relu::FLOAT, ("uround", DType::F32) => strided::round::FLOAT, ("utanh", DType::F32) => strided::tanh::FLOAT, + ("ucos", DType::F16) => strided::cos::HALF, ("usin", DType::F16) => strided::sin::HALF, ("usqr", DType::F16) => strided::sqr::HALF, @@ -557,6 +558,25 @@ impl BackendStorage for MetalStorage { ("urelu", DType::F16) => strided::relu::HALF, ("uround", DType::F16) => strided::round::HALF, ("utanh", DType::F16) => strided::tanh::HALF, + + ("ucos", DType::BF16) => strided::cos::BFLOAT, + ("usin", DType::BF16) => strided::sin::BFLOAT, + ("usqr", DType::BF16) => strided::sqr::BFLOAT, + ("usqrt", DType::BF16) => strided::sqrt::BFLOAT, + ("uneg", DType::BF16) => strided::neg::BFLOAT, + ("uexp", DType::BF16) => strided::exp::BFLOAT, + ("ulog", DType::BF16) => strided::log::BFLOAT, + ("ugelu", DType::BF16) => strided::gelu::BFLOAT, + ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT, + ("uerf", DType::BF16) => strided::erf::BFLOAT, + ("usilu", DType::BF16) => strided::silu::BFLOAT, + ("uabs", DType::BF16) => strided::abs::BFLOAT, + ("uceil", DType::BF16) => strided::ceil::BFLOAT, + ("ufloor", DType::BF16) => strided::floor::BFLOAT, + ("urelu", DType::BF16) => strided::relu::BFLOAT, + ("uround", DType::BF16) => strided::round::BFLOAT, + ("utanh", DType::BF16) => strided::tanh::BFLOAT, + (name, dtype) => { crate::bail!("Metal strided unary {name} {dtype:?} not implemented") } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 4b6363ed..ec793eae 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -175,5 +175,5 @@ BFLOAT_UNARY_OP(sign) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) -COPY2D(copy2d_bf64, bfloat) +COPY2D(copy2d_bf16, bfloat) #endif |