author     laurent <laurent.mazare@gmail.com>  2023-06-21 10:25:56 +0100
committer  laurent <laurent.mazare@gmail.com>  2023-06-21 10:25:56 +0100
commit     eb52b9b343819c547b8e4c47a8ff70cb7c632fbb
tree       0eb1a76af828a00bfa59776e6c30f2f660c18197 /src/cpu_backend.rs
parent     b3eb57cd0a696ec184e47c7316871b01e0a45aea
Move the cpu backend specific bits apart.
Diffstat (limited to 'src/cpu_backend.rs')
-rw-r--r--  src/cpu_backend.rs | 99
1 file changed, 99 insertions, 0 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
new file mode 100644
index 00000000..03068866
--- /dev/null
+++ b/src/cpu_backend.rs
@@ -0,0 +1,99 @@
+use crate::storage::{BinaryOp, UnaryOp};
+use crate::{DType, Error, Result, Shape, StridedIndex};
+
+// TODO: Think about whether we would be better off with a dtype and
+// a buffer as an owned slice of bytes.
+#[derive(Debug, Clone)]
+pub enum CpuStorage {
+    F32(Vec<f32>),
+    F64(Vec<f64>),
+}
+
+impl CpuStorage {
+    pub fn dtype(&self) -> DType {
+        match self {
+            Self::F32(_) => DType::F32,
+            Self::F64(_) => DType::F64,
+        }
+    }
+
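+    /// Computes `x * mul + add` for every element selected by the strided view,
+    /// returning a new contiguous buffer of the same dtype.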
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        match self {
+            Self::F32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let mul = mul as f32;
+                let add = add as f32;
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| storage[i] * mul + add).collect();
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
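+    /// Applies the unary op `B` elementwise over the strided view, dispatching
+    /// to `B::f32` or `B::f64` depending on the stored dtype.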
+    pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        // TODO: Different code path for the contiguous case?
+        match self {
+            Self::F32(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| B::f32(storage[i])).collect();
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let index = StridedIndex::new(shape.dims(), stride);
+                let data = index.map(|i| B::f64(storage[i])).collect();
+                Ok(Self::F64(data))
+            }
+        }
+    }
+
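+    /// Applies the binary op `B` elementwise to two strided views that share
+    /// `shape`; both operands must have the same dtype.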
+    pub(crate) fn binary_impl<B: BinaryOp>(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        // The ggml implementation has different code paths depending on whether the rhs is
+        // contiguous. For now we only handle the general case, but we should benchmark and
+        // add the same specialization if it helps.
+        // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
+        match (self, rhs) {
+            (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
+                let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+                let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+                let data = lhs_index
+                    .zip(rhs_index)
+                    .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
+                    .collect();
+                Ok(Self::F32(data))
+            }
+            (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
+                let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+                let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+                let data = lhs_index
+                    .zip(rhs_index)
+                    .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
+                    .collect();
+                Ok(Self::F64(data))
+            }
+            _ => {
+                // This mismatch should have been caught by a dtype check at the call site.
+                Err(Error::DTypeMismatchBinaryOp {
+                    lhs: self.dtype(),
+                    rhs: rhs.dtype(),
+                    op: B::NAME,
+                })
+            }
+        }
+    }
+}
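
Note: the `UnaryOp` and `BinaryOp` bounds used above come from `src/storage.rs`, which is not part of this diff. Below is a minimal sketch of the shape those traits would need to satisfy the call sites in this file (an associated `NAME` plus per-dtype functions), together with a hypothetical `Add` op and a small usage example. The trait bodies are inferred from usage here, and the `Add` op, the strides, and the `Shape` construction are illustrative assumptions, not the crate's actual definitions; `binary_impl` is `pub(crate)`, so this would only compile inside the crate.

// Sketch only: inferred from the call sites above (B::f32, B::f64, B::NAME).
pub trait UnaryOp {
    const NAME: &'static str;
    fn f32(v: f32) -> f32;
    fn f64(v: f64) -> f64;
}

pub trait BinaryOp {
    const NAME: &'static str;
    fn f32(lhs: f32, rhs: f32) -> f32;
    fn f64(lhs: f64, rhs: f64) -> f64;
}

// Hypothetical elementwise addition op, for illustration only.
struct Add;

impl BinaryOp for Add {
    const NAME: &'static str = "add";
    fn f32(lhs: f32, rhs: f32) -> f32 {
        lhs + rhs
    }
    fn f64(lhs: f64, rhs: f64) -> f64 {
        lhs + rhs
    }
}

// Usage sketch: adding two contiguous 2x2 f32 tensors (row-major strides [2, 1]).
fn add_example() -> Result<CpuStorage> {
    let lhs = CpuStorage::F32(vec![1.0, 2.0, 3.0, 4.0]);
    let rhs = CpuStorage::F32(vec![10.0, 20.0, 30.0, 40.0]);
    let shape = Shape::from((2, 2)); // assumes a From<(usize, usize)> impl for Shape
    let sum = lhs.binary_impl::<Add>(&rhs, &shape, &[2, 1], &[2, 1])?;
    // sum is CpuStorage::F32(vec![11.0, 22.0, 33.0, 44.0])
    Ok(sum)
}

The enum-per-dtype layout keeps each kernel monomorphic over the element type at the cost of one `match` per operation, which is the trade-off the TODO at the top of the file is weighing against a `(DType, Vec<u8>)` representation.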