summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--candle-core/src/cpu_backend.rs111
1 files changed, 34 insertions, 77 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 7170e470..136eeaba 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -292,6 +292,35 @@ impl Map2 for MatMul {
}
}
+fn divide_by_sum_over_dim<T: WithDType + num_traits::NumAssign>(
+ s: &mut [T],
+ shape: &Shape,
+ dim: usize,
+) -> Result<()> {
+ // [self] stores data in a contiguous way starting at offset 0.
+ let dims = shape.dims();
+ let elem_per_slice = dims[dim];
+ let prod_pre_dim = dims[..dim].iter().product();
+ let prod_post_dim = dims[dim + 1..].iter().product();
+ for pre_idx in 0..prod_pre_dim {
+ for post_idx in 0..prod_post_dim {
+ let mut sum = 0f64;
+ let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+ for _ in 0..elem_per_slice {
+ sum += s[idx].to_f64();
+ idx += prod_post_dim
+ }
+ let sum = T::from_f64(sum);
+ let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+ for _ in 0..elem_per_slice {
+ s[idx] /= sum;
+ idx += prod_post_dim
+ }
+ }
+ }
+ Ok(())
+}
+
impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
@@ -437,85 +466,13 @@ impl CpuStorage {
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way starting at offset 0.
- let dims = shape.dims();
- let elem_per_slice = dims[dim];
- let prod_pre_dim = dims[..dim].iter().product();
- let prod_post_dim = dims[dim + 1..].iter().product();
match self {
- Self::BF16(storage) => {
- for pre_idx in 0..prod_pre_dim {
- for post_idx in 0..prod_post_dim {
- let mut sum = 0f64;
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- sum += storage[idx].to_f64();
- idx += prod_post_dim
- }
- let sum = bf16::from_f64(sum);
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- storage[idx] /= sum;
- idx += prod_post_dim
- }
- }
- }
- }
- Self::F16(storage) => {
- for pre_idx in 0..prod_pre_dim {
- for post_idx in 0..prod_post_dim {
- let mut sum = 0f64;
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- sum += storage[idx].to_f64();
- idx += prod_post_dim
- }
- let sum = f16::from_f64(sum);
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- storage[idx] /= sum;
- idx += prod_post_dim
- }
- }
- }
- }
- Self::F32(storage) => {
- for pre_idx in 0..prod_pre_dim {
- for post_idx in 0..prod_post_dim {
- let mut sum = 0f64;
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- sum += storage[idx] as f64;
- idx += prod_post_dim
- }
- let sum = sum as f32;
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- storage[idx] /= sum;
- idx += prod_post_dim
- }
- }
- }
- }
- Self::F64(storage) => {
- for pre_idx in 0..prod_pre_dim {
- for post_idx in 0..prod_post_dim {
- let mut sum = 0f64;
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- sum += storage[idx];
- idx += prod_post_dim
- }
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- storage[idx] /= sum;
- idx += prod_post_dim
- }
- }
- }
- }
- Self::U32(_) => {}
+ Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
+ Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
+ Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
+ Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
+ Self::U32(_) => Ok(()),
}
- Ok(())
}
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {