summaryrefslogtreecommitdiff
path: root/candle-core/src/metal_backend
diff options
context:
space:
mode:
authorivarflakstad <69173633+ivarflakstad@users.noreply.github.com>2024-04-14 20:01:13 +0200
committerGitHub <noreply@github.com>2024-04-14 20:01:13 +0200
commitdb7dbf3071e2da47314096bb04d8f4d99626d4ca (patch)
tree4085761a4cfee3a7997197a200115d6a58651d42 /candle-core/src/metal_backend
parent4ecedb15981c7141df789db597140ed96c89e7dd (diff)
downloadcandle-db7dbf3071e2da47314096bb04d8f4d99626d4ca.tar.gz
candle-db7dbf3071e2da47314096bb04d8f4d99626d4ca.tar.bz2
candle-db7dbf3071e2da47314096bb04d8f4d99626d4ca.zip
Add missing bfloat unary strided kernels and fix typo (#2058)
Diffstat (limited to 'candle-core/src/metal_backend')
-rw-r--r--candle-core/src/metal_backend/mod.rs20
1 files changed, 20 insertions, 0 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 6f1f64ee..daa68c39 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -540,6 +540,7 @@ impl BackendStorage for MetalStorage {
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
+
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
@@ -557,6 +558,25 @@ impl BackendStorage for MetalStorage {
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
+
+ ("ucos", DType::BF16) => strided::cos::BFLOAT,
+ ("usin", DType::BF16) => strided::sin::BFLOAT,
+ ("usqr", DType::BF16) => strided::sqr::BFLOAT,
+ ("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
+ ("uneg", DType::BF16) => strided::neg::BFLOAT,
+ ("uexp", DType::BF16) => strided::exp::BFLOAT,
+ ("ulog", DType::BF16) => strided::log::BFLOAT,
+ ("ugelu", DType::BF16) => strided::gelu::BFLOAT,
+ ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
+ ("uerf", DType::BF16) => strided::erf::BFLOAT,
+ ("usilu", DType::BF16) => strided::silu::BFLOAT,
+ ("uabs", DType::BF16) => strided::abs::BFLOAT,
+ ("uceil", DType::BF16) => strided::ceil::BFLOAT,
+ ("ufloor", DType::BF16) => strided::floor::BFLOAT,
+ ("urelu", DType::BF16) => strided::relu::BFLOAT,
+ ("uround", DType::BF16) => strided::round::BFLOAT,
+ ("utanh", DType::BF16) => strided::tanh::BFLOAT,
+
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}