diff options
author | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-14 17:10:54 +0100 |
---|---|---|
committer | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-14 17:10:54 +0100 |
commit | ecf88a6d381e40c8db1c643dff2753fd877fae92 (patch) | |
tree | 9b6db2ec9a37a48185f323ab4c5e8b0baaa20221 /candle-core/src | |
parent | e06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c (diff) | |
parent | e6d86b081980196745e5f0b0eda8ce5334c0ff67 (diff) | |
download | candle-ecf88a6d381e40c8db1c643dff2753fd877fae92.tar.gz candle-ecf88a6d381e40c8db1c643dff2753fd877fae92.tar.bz2 candle-ecf88a6d381e40c8db1c643dff2753fd877fae92.zip |
Merge branch 'main' into ivarflakstad/metal-prng
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/metal_backend.rs | 3 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 12 |
2 files changed, 14 insertions, 1 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 48250233..8a75bd7c 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -355,6 +355,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32", DType::F16 => "affine_f16", + DType::BF16 => "affine_bf16", dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine( @@ -373,6 +374,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32_strided", DType::F16 => "affine_f16_strided", + DType::BF16 => "affine_bf16_strided", dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine_strided( @@ -808,6 +810,7 @@ impl BackendStorage for MetalStorage { } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 54f9fa2b..3100c6e8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2578,11 +2578,21 @@ impl Tensor { } /// Returns log(sum(exp(tensor), dim)). - pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> { + pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { let exp = self.exp()?; let sum = exp.sum(sum_dims)?; sum.log() } + + /// Pointwise pow operation. + pub fn pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.mul(&self.log()?)?.exp() + } + + /// Broadcasting version of `pow`. + pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.broadcast_mul(&self.log()?)?.exp() + } } macro_rules! bin_trait { |