author     laurent <laurent.mazare@gmail.com>  2023-06-21 10:25:56 +0100
committer  laurent <laurent.mazare@gmail.com>  2023-06-21 10:25:56 +0100
commit     eb52b9b343819c547b8e4c47a8ff70cb7c632fbb
tree       0eb1a76af828a00bfa59776e6c30f2f660c18197
parent     b3eb57cd0a696ec184e47c7316871b01e0a45aea
Move the cpu backend specific bits apart.
-rw-r--r--  src/cpu_backend.rs  99
-rw-r--r--  src/device.rs        5
-rw-r--r--  src/lib.rs           4
-rw-r--r--  src/storage.rs      93
4 files changed, 118 insertions(+), 83 deletions(-)
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
new file mode 100644
index 00000000..03068866
--- /dev/null
+++ b/src/cpu_backend.rs
@@ -0,0 +1,99 @@
+use crate::storage::{BinaryOp, UnaryOp};
+use crate::{DType, Error, Result, Shape, StridedIndex};
+
+// TODO: Think about whether we would be better off with a dtype and
+// a buffer as an owned slice of bytes.
+#[derive(Debug, Clone)]
+pub enum CpuStorage {
+ F32(Vec<f32>),
+ F64(Vec<f64>),
+}
+
+impl CpuStorage {
+ pub fn dtype(&self) -> DType {
+ match self {
+ Self::F32(_) => DType::F32,
+ Self::F64(_) => DType::F64,
+ }
+ }
+
+ pub(crate) fn affine_impl(
+ &self,
+ shape: &Shape,
+ stride: &[usize],
+ mul: f64,
+ add: f64,
+ ) -> Result<Self> {
+ match self {
+ Self::F32(storage) => {
+ let index = StridedIndex::new(shape.dims(), stride);
+ let mul = mul as f32;
+ let add = add as f32;
+ let data = index.map(|i| storage[i] * mul + add).collect();
+ Ok(Self::F32(data))
+ }
+ Self::F64(storage) => {
+ let index = StridedIndex::new(shape.dims(), stride);
+ let data = index.map(|i| storage[i] * mul + add).collect();
+ Ok(Self::F64(data))
+ }
+ }
+ }
+
+ pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+ // TODO: Different code path for the contiguous case?
+ match self {
+ Self::F32(storage) => {
+ let index = StridedIndex::new(shape.dims(), stride);
+ let data = index.map(|i| B::f32(storage[i])).collect();
+ Ok(Self::F32(data))
+ }
+ Self::F64(storage) => {
+ let index = StridedIndex::new(shape.dims(), stride);
+ let data = index.map(|i| B::f64(storage[i])).collect();
+ Ok(Self::F64(data))
+ }
+ }
+ }
+
+ pub(crate) fn binary_impl<B: BinaryOp>(
+ &self,
+ rhs: &Self,
+ shape: &Shape,
+ lhs_stride: &[usize],
+ rhs_stride: &[usize],
+ ) -> Result<Self> {
+ // The ggml implementation has different paths based on whether the rhs is contiguous
+ // or not, for now we only consider the general case but we should benchmark and do the
+ // same if it helps.
+ // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
+ match (self, rhs) {
+ (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
+ let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+ let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+ let data = lhs_index
+ .zip(rhs_index)
+ .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
+ .collect();
+ Ok(Self::F32(data))
+ }
+ (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
+ let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
+ let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
+ let data = lhs_index
+ .zip(rhs_index)
+ .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
+ .collect();
+ Ok(Self::F64(data))
+ }
+ _ => {
+ // This should be covered by the dtype check above.
+ Err(Error::DTypeMismatchBinaryOp {
+ lhs: self.dtype(),
+ rhs: rhs.dtype(),
+ op: B::NAME,
+ })
+ }
+ }
+ }
+}
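
The new `unary_impl` is generic over the `UnaryOp` trait (still defined in src/storage.rs, widened to `pub(crate)` below), so each op only supplies per-dtype scalar functions and the backend handles iteration and dtype dispatch. A minimal sketch of an implementor; `Sqr` is hypothetical and not part of this commit:

// Implements the pub(crate) UnaryOp trait from src/storage.rs.
struct Sqr;

impl UnaryOp for Sqr {
    const NAME: &'static str = "sqr";
    fn f32(v1: f32) -> f32 {
        v1 * v1
    }
    fn f64(v1: f64) -> f64 {
        v1 * v1
    }
}

// The dtype arm is picked at runtime, the scalar fn at compile time:
// let squared = cpu_storage.unary_impl::<Sqr>(&shape, &stride)?;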
diff --git a/src/device.rs b/src/device.rs
index c092a347..3677cfff 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -1,7 +1,4 @@
-use crate::{
- storage::{CpuStorage, Storage},
- DType, Result, Shape,
-};
+use crate::{CpuStorage, DType, Result, Shape, Storage};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Device {
diff --git a/src/lib.rs b/src/lib.rs
index 58c2ba52..175d36ad 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+mod cpu_backend;
mod device;
mod dtype;
mod error;
@@ -7,10 +8,11 @@ mod storage;
mod strided_index;
mod tensor;
+pub use cpu_backend::CpuStorage;
pub use device::Device;
pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use shape::Shape;
-pub use storage::{CpuStorage, Storage};
+pub use storage::Storage;
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};
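
Because `CpuStorage` is re-exported from the crate root, the move between modules is invisible to downstream users. A sketch, assuming the crate is consumed under the name `candle`:

use candle::{CpuStorage, Storage};

// Compiles identically before and after this commit.
fn as_cpu(storage: &Storage) -> Option<&CpuStorage> {
    match storage {
        Storage::Cpu(cpu) => Some(cpu),
        Storage::Cuda { .. } => None,
    }
}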
diff --git a/src/storage.rs b/src/storage.rs
index 30161a2c..7083cc28 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -1,21 +1,4 @@
-use crate::{DType, Device, Error, Result, Shape, StridedIndex};
-
-// TODO: Think about whether we would be better off with a dtype and
-// a buffer as an owned slice of bytes.
-#[derive(Debug, Clone)]
-pub enum CpuStorage {
- F32(Vec<f32>),
- F64(Vec<f64>),
-}
-
-impl CpuStorage {
- pub(crate) fn dtype(&self) -> DType {
- match self {
- Self::F32(_) => DType::F32,
- Self::F64(_) => DType::F64,
- }
- }
-}
+use crate::{CpuStorage, DType, Device, Error, Result, Shape};
#[derive(Debug, Clone)]
pub enum Storage {
@@ -23,13 +6,13 @@ pub enum Storage {
Cuda { gpu_id: usize }, // TODO: Actually add the storage.
}
-trait UnaryOp {
+pub(crate) trait UnaryOp {
const NAME: &'static str;
fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64;
}
-trait BinaryOp {
+pub(crate) trait BinaryOp {
const NAME: &'static str;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
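
Widening these traits from private to `pub(crate)` is what allows the new cpu_backend module to be generic over them. A `BinaryOp` implementor follows the same shape as the unary sketch above; `Add` here is illustrative, the commit itself does not touch any concrete ops:

struct Add;

impl BinaryOp for Add {
    const NAME: &'static str = "add";
    fn f32(v1: f32, v2: f32) -> f32 {
        v1 + v2
    }
    fn f64(v1: f64, v2: f64) -> f64 {
        v1 + v2
    }
}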
@@ -157,20 +140,10 @@ impl Storage {
) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
- Storage::Cpu(storage) => match storage {
- CpuStorage::F32(storage) => {
- let index = StridedIndex::new(shape.dims(), stride);
- let mul = mul as f32;
- let add = add as f32;
- let data = index.map(|i| storage[i] * mul + add).collect();
- Ok(Storage::Cpu(CpuStorage::F32(data)))
- }
- CpuStorage::F64(storage) => {
- let index = StridedIndex::new(shape.dims(), stride);
- let data = index.map(|i| storage[i] * mul + add).collect();
- Ok(Storage::Cpu(CpuStorage::F64(data)))
- }
- },
+ Storage::Cpu(storage) => {
+ let storage = storage.affine_impl(shape, stride, mul, add)?;
+ Ok(Self::Cpu(storage))
+ }
Self::Cuda { .. } => todo!(),
}
}
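
After the refactor, `Storage::affine_impl` reduces to a one-line delegation; the arithmetic is unchanged. In scalar form the CPU kernel computes `out[i] = x[i] * mul + add`, narrowing the f64 scalars once for the F32 arm rather than per element. A small worked check with a hypothetical helper that mirrors the F32 arm:

fn affine_f32(xs: &[f32], mul: f64, add: f64) -> Vec<f32> {
    // Narrow the scalars once, as the F32 arm does.
    let (mul, add) = (mul as f32, add as f32);
    xs.iter().map(|&x| x * mul + add).collect()
}

fn main() {
    assert_eq!(affine_f32(&[2.0, 4.0], 0.5, 1.0), vec![2.0, 3.0]);
}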
@@ -178,18 +151,10 @@ impl Storage {
fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
- Storage::Cpu(storage) => match storage {
- CpuStorage::F32(storage) => {
- let index = StridedIndex::new(shape.dims(), stride);
- let data = index.map(|i| B::f32(storage[i])).collect();
- Ok(Storage::Cpu(CpuStorage::F32(data)))
- }
- CpuStorage::F64(storage) => {
- let index = StridedIndex::new(shape.dims(), stride);
- let data = index.map(|i| B::f64(storage[i])).collect();
- Ok(Storage::Cpu(CpuStorage::F64(data)))
- }
- },
+ Storage::Cpu(storage) => {
+ let storage = storage.unary_impl::<B>(shape, stride)?;
+ Ok(Self::Cpu(storage))
+ }
Self::Cuda { .. } => todo!(),
}
}
@@ -204,39 +169,11 @@ impl Storage {
) -> Result<Self> {
self.same_device(rhs, B::NAME)?;
self.same_dtype(rhs, B::NAME)?;
- // The ggml implementation has different paths based on whether the rhs is contiguous
- // or not, for now we only consider the general case but we should benchmark and do the
- // same if it helps.
- // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
match (self, rhs) {
- (Storage::Cpu(lhs), Storage::Cpu(rhs)) => match (lhs, rhs) {
- (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
- let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
- let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
- let data = lhs_index
- .zip(rhs_index)
- .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
- .collect();
- Ok(Storage::Cpu(CpuStorage::F32(data)))
- }
- (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
- let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
- let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
- let data = lhs_index
- .zip(rhs_index)
- .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
- .collect();
- Ok(Storage::Cpu(CpuStorage::F64(data)))
- }
- _ => {
- // This should be covered by the dtype check above.
- Err(Error::DTypeMismatchBinaryOp {
- lhs: lhs.dtype(),
- rhs: rhs.dtype(),
- op: B::NAME,
- })
- }
- },
+ (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+ let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
+ Ok(Self::Cpu(storage))
+ }
(Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive