diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/metal_backend.rs | 3 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 12 |
2 files changed, 14 insertions, 1 deletion
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 48250233..8a75bd7c 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -355,6 +355,7 @@ impl BackendStorage for MetalStorage {
             let name = match self.dtype {
                 DType::F32 => "affine_f32",
                 DType::F16 => "affine_f16",
+                DType::BF16 => "affine_bf16",
                 dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
             };
             candle_metal_kernels::call_affine(
@@ -373,6 +374,7 @@ impl BackendStorage for MetalStorage {
             let name = match self.dtype {
                 DType::F32 => "affine_f32_strided",
                 DType::F16 => "affine_f16_strided",
+                DType::BF16 => "affine_bf16_strided",
                 dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
             };
             candle_metal_kernels::call_affine_strided(
@@ -808,6 +810,7 @@ impl BackendStorage for MetalStorage {
         }
         let name = match (self.dtype, t.dtype()) {
             (DType::U8, DType::F32) => "where_u8_f32",
+            (DType::U8, DType::BF16) => "where_u8_bf16",
             (DType::U8, DType::F16) => "where_u8_f16",
             (DType::U8, DType::I64) => "where_u8_i64",
             (DType::U8, DType::U32) => "where_u8_u32",
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 54f9fa2b..3100c6e8 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2578,11 +2578,21 @@ impl Tensor {
     }
 
     /// Returns log(sum(exp(tensor), dim)).
-    pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
         let exp = self.exp()?;
         let sum = exp.sum(sum_dims)?;
         sum.log()
     }
+
+    /// Pointwise pow operation.
+    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
+        rhs.mul(&self.log()?)?.exp()
+    }
+
+    /// Broadcasting version of `pow`.
+    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
+        rhs.broadcast_mul(&self.log()?)?.exp()
+    }
 }
 
 macro_rules! bin_trait {