diff options
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 111 |
1 files changed, 34 insertions, 77 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 7170e470..136eeaba 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -292,6 +292,35 @@ impl Map2 for MatMul { } } +fn divide_by_sum_over_dim<T: WithDType + num_traits::NumAssign>( + s: &mut [T], + shape: &Shape, + dim: usize, +) -> Result<()> { + // [self] stores data in a contiguous way starting at offset 0. + let dims = shape.dims(); + let elem_per_slice = dims[dim]; + let prod_pre_dim = dims[..dim].iter().product(); + let prod_post_dim = dims[dim + 1..].iter().product(); + for pre_idx in 0..prod_pre_dim { + for post_idx in 0..prod_post_dim { + let mut sum = 0f64; + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + sum += s[idx].to_f64(); + idx += prod_post_dim + } + let sum = T::from_f64(sum); + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + s[idx] /= sum; + idx += prod_post_dim + } + } + } + Ok(()) +} + impl CpuStorage { pub fn dtype(&self) -> DType { match self { @@ -437,85 +466,13 @@ impl CpuStorage { pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { // [self] stores data in a contiguous way starting at offset 0. - let dims = shape.dims(); - let elem_per_slice = dims[dim]; - let prod_pre_dim = dims[..dim].iter().product(); - let prod_post_dim = dims[dim + 1..].iter().product(); match self { - Self::BF16(storage) => { - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { - let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += storage[idx].to_f64(); - idx += prod_post_dim - } - let sum = bf16::from_f64(sum); - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - storage[idx] /= sum; - idx += prod_post_dim - } - } - } - } - Self::F16(storage) => { - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { - let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += storage[idx].to_f64(); - idx += prod_post_dim - } - let sum = f16::from_f64(sum); - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - storage[idx] /= sum; - idx += prod_post_dim - } - } - } - } - Self::F32(storage) => { - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { - let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += storage[idx] as f64; - idx += prod_post_dim - } - let sum = sum as f32; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - storage[idx] /= sum; - idx += prod_post_dim - } - } - } - } - Self::F64(storage) => { - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { - let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += storage[idx]; - idx += prod_post_dim - } - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - storage[idx] /= sum; - idx += prod_post_dim - } - } - } - } - Self::U32(_) => {} + Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim), + Self::F16(s) => divide_by_sum_over_dim(s, shape, dim), + Self::F32(s) => divide_by_sum_over_dim(s, shape, dim), + Self::F64(s) => divide_by_sum_over_dim(s, shape, dim), + Self::U32(_) => Ok(()), } - Ok(()) } pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> { |