author | laurent <laurent.mazare@gmail.com> | 2023-06-21 10:25:56 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-21 10:25:56 +0100 |
commit | eb52b9b343819c547b8e4c47a8ff70cb7c632fbb (patch) | |
tree | 0eb1a76af828a00bfa59776e6c30f2f660c18197 /src/cpu_backend.rs | |
parent | b3eb57cd0a696ec184e47c7316871b01e0a45aea (diff) | |
download | candle-eb52b9b343819c547b8e4c47a8ff70cb7c632fbb.tar.gz candle-eb52b9b343819c547b8e4c47a8ff70cb7c632fbb.tar.bz2 candle-eb52b9b343819c547b8e4c47a8ff70cb7c632fbb.zip |
Move the cpu backend specific bits apart.
Diffstat (limited to 'src/cpu_backend.rs')
-rw-r--r-- | src/cpu_backend.rs | 99 |
1 file changed, 99 insertions, 0 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
new file mode 100644
index 00000000..03068866
--- /dev/null
+++ b/src/cpu_backend.rs
@@ -0,0 +1,99 @@
+use crate::storage::{BinaryOp, UnaryOp};
+use crate::{DType, Error, Result, Shape, StridedIndex};
+
+// TODO: Think about whether we would be better off with a dtype and
+// a buffer as an owned slice of bytes.
+#[derive(Debug, Clone)]
+pub enum CpuStorage {
+    F32(Vec<f32>),
+    F64(Vec<f64>),
+}
+
+impl CpuStorage {
+    pub fn dtype(&self) -> DType {
+        match self {
+            Self::F32(_) => DType::F32,
+            Self::F64(_) => DType::F64,
+        }
+    }
+
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        match self {
+            Self::F32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let mul = mul as f32;
+                let add = add as f32;
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        // TODO: Different code path for the contiguous case?
+        match self {
+            Self::F32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| B::f32(storage[i])).collect();
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| B::f64(storage[i])).collect();
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn binary_impl<B: BinaryOp>(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        // The ggml implementation has different paths based on whether the rhs is contiguous
+        // or not, for now we only consider the general case but we should benchmark and do the
+        // same if it helps.
+        // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
+        match (self, rhs) {
+            (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
+                let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+                let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+                let data = lhs_index
+                    .zip(rhs_index)
+                    .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
+                    .collect();
+                Ok(Self::F32(data))
+            }
+            (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
+                let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+                let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+                let data = lhs_index
+                    .zip(rhs_index)
+                    .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
+                    .collect();
+                Ok(Self::F64(data))
+            }
+            _ => {
+                // This should be covered by the dtype check above.
+                Err(Error::DTypeMismatchBinaryOp {
+                    lhs: self.dtype(),
+                    rhs: rhs.dtype(),
+                    op: B::NAME,
+                })
+            }
+        }
+    }
+}
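The new file leans on `StridedIndex` (re-exported from the crate root) and the `UnaryOp`/`BinaryOp` traits (from `crate::storage`), none of which appear in this diff. The following is a minimal, self-contained sketch of how such pieces could fit together; the `UnaryOp` trait shape, the `Sqr` op, and this `StridedIndex` implementation are assumptions made for illustration, not the crate's actual definitions.

// Illustrative sketch only: the real `UnaryOp`, `BinaryOp`, and `StridedIndex`
// live in other modules of the crate; the definitions below are assumptions
// that mirror how they are used in the diff above.

// A unary op is dispatched per dtype through associated functions, matching
// the `B::f32(...)` / `B::f64(...)` calls in `unary_impl`.
trait UnaryOp {
    const NAME: &'static str;
    fn f32(v: f32) -> f32;
    fn f64(v: f64) -> f64;
}

struct Sqr;

impl UnaryOp for Sqr {
    const NAME: &'static str = "sqr";
    fn f32(v: f32) -> f32 {
        v * v
    }
    fn f64(v: f64) -> f64 {
        v * v
    }
}

// Walks a tensor's logical positions in row-major order and yields the
// corresponding offsets into the flat storage buffer, using the strides.
struct StridedIndex<'a> {
    next_index: Option<usize>,
    multi_index: Vec<usize>,
    dims: &'a [usize],
    stride: &'a [usize],
}

impl<'a> StridedIndex<'a> {
    fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
        let elem_count: usize = dims.iter().product();
        let next_index = if elem_count == 0 { None } else { Some(0) };
        Self {
            next_index,
            multi_index: vec![0; dims.len()],
            dims,
            stride,
        }
    }
}

impl<'a> Iterator for StridedIndex<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let storage_index = self.next_index?;
        // Advance the logical multi-index, rightmost dimension first.
        let mut updated = false;
        for (multi_i, &dim) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
            if *multi_i + 1 < dim {
                *multi_i += 1;
                updated = true;
                break;
            }
            *multi_i = 0;
        }
        // Recompute the flat storage offset from the strides, or stop once
        // every dimension has wrapped around.
        self.next_index = if updated {
            Some(
                self.multi_index
                    .iter()
                    .zip(self.stride.iter())
                    .map(|(&i, &s)| i * s)
                    .sum(),
            )
        } else {
            None
        };
        Some(storage_index)
    }
}

fn main() {
    // A 2x3 tensor stored row-major: strides are [3, 1].
    let storage = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let index = StridedIndex::new(&[2, 3], &[3, 1]);
    // Same shape as the dispatch in `unary_impl`: map each storage offset
    // through the op and collect the result.
    let squared: Vec<f32> = index.map(|i| Sqr::f32(storage[i])).collect();
    assert_eq!(squared, vec![1.0, 4.0, 9.0, 16.0, 25.0, 36.0]);
    println!("{squared:?}");
}

The point of iterating through a strided index is that the same `unary_impl`/`binary_impl` code handles contiguous and non-contiguous layouts alike, at the cost of the per-element index arithmetic that the TODO about a contiguous fast path alludes to.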