diff options
author | ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-04-14 20:01:13 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-14 20:01:13 +0200 |
commit | db7dbf3071e2da47314096bb04d8f4d99626d4ca (patch) | |
tree | 4085761a4cfee3a7997197a200115d6a58651d42 /candle-core/src/metal_backend | |
parent | 4ecedb15981c7141df789db597140ed96c89e7dd (diff) | |
download | candle-db7dbf3071e2da47314096bb04d8f4d99626d4ca.tar.gz candle-db7dbf3071e2da47314096bb04d8f4d99626d4ca.tar.bz2 candle-db7dbf3071e2da47314096bb04d8f4d99626d4ca.zip |
Add missing bfloat unary strided kernels and fix typo (#2058)
Diffstat (limited to 'candle-core/src/metal_backend')
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 6f1f64ee..daa68c39 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -540,6 +540,7 @@ impl BackendStorage for MetalStorage { ("urelu", DType::F32) => strided::relu::FLOAT, ("uround", DType::F32) => strided::round::FLOAT, ("utanh", DType::F32) => strided::tanh::FLOAT, + ("ucos", DType::F16) => strided::cos::HALF, ("usin", DType::F16) => strided::sin::HALF, ("usqr", DType::F16) => strided::sqr::HALF, @@ -557,6 +558,25 @@ impl BackendStorage for MetalStorage { ("urelu", DType::F16) => strided::relu::HALF, ("uround", DType::F16) => strided::round::HALF, ("utanh", DType::F16) => strided::tanh::HALF, + + ("ucos", DType::BF16) => strided::cos::BFLOAT, + ("usin", DType::BF16) => strided::sin::BFLOAT, + ("usqr", DType::BF16) => strided::sqr::BFLOAT, + ("usqrt", DType::BF16) => strided::sqrt::BFLOAT, + ("uneg", DType::BF16) => strided::neg::BFLOAT, + ("uexp", DType::BF16) => strided::exp::BFLOAT, + ("ulog", DType::BF16) => strided::log::BFLOAT, + ("ugelu", DType::BF16) => strided::gelu::BFLOAT, + ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT, + ("uerf", DType::BF16) => strided::erf::BFLOAT, + ("usilu", DType::BF16) => strided::silu::BFLOAT, + ("uabs", DType::BF16) => strided::abs::BFLOAT, + ("uceil", DType::BF16) => strided::ceil::BFLOAT, + ("ufloor", DType::BF16) => strided::floor::BFLOAT, + ("urelu", DType::BF16) => strided::relu::BFLOAT, + ("uround", DType::BF16) => strided::round::BFLOAT, + ("utanh", DType::BF16) => strided::tanh::BFLOAT, + (name, dtype) => { crate::bail!("Metal strided unary {name} {dtype:?} not implemented") } |