summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/metal_backend.rs3
-rw-r--r--candle-core/src/tensor.rs12
2 files changed, 14 insertions, 1 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 48250233..8a75bd7c 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -355,6 +355,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
+ DType::BF16 => "affine_bf16",
dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine(
@@ -373,6 +374,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "affine_f32_strided",
DType::F16 => "affine_f16_strided",
+ DType::BF16 => "affine_bf16_strided",
dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
};
candle_metal_kernels::call_affine_strided(
@@ -808,6 +810,7 @@ impl BackendStorage for MetalStorage {
}
let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32",
+ (DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64",
(DType::U8, DType::U32) => "where_u8_u32",
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 54f9fa2b..3100c6e8 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2578,11 +2578,21 @@ impl Tensor {
}
/// Returns log(sum(exp(tensor), dim)).
- pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+ pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
+
+ /// Pointwise pow operation.
+ pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
+ rhs.mul(&self.log()?)?.exp()
+ }
+
+ /// Broadcasting version of `pow`.
+ pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
+ rhs.broadcast_mul(&self.log()?)?.exp()
+ }
}
macro_rules! bin_trait {