diff options
Diffstat (limited to 'candle-core/src/storage.rs')
-rw-r--r-- | candle-core/src/storage.rs | 94 |
1 files changed, 35 insertions, 59 deletions
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index e44a2db6..7acf6dd0 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -1,4 +1,4 @@ -use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape}; +use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape}; // We do not want to implement Clone on Storage as cloning may fail because of // out of memory. Instead try_clone should be used. @@ -53,38 +53,33 @@ impl Storage { } } - pub(crate) fn affine_impl( - &self, - shape: &Shape, - stride: &[usize], - mul: f64, - add: f64, - ) -> Result<Self> { + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> { match self { Storage::Cpu(storage) => { - let storage = storage.affine_impl(shape, stride, mul, add)?; + let storage = storage.affine(layout, mul, add)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.affine_impl(shape, stride, mul, add)?; + let storage = storage.affine(layout, mul, add)?; Ok(Self::Cuda(storage)) } } } - pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> { + pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> { match self { Storage::Cpu(storage) => { - let storage = storage.sum(shape, stride, s)?; + let storage = storage.sum(layout, s)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.sum(shape, stride, s)?; + let storage = storage.sum(layout, s)?; Ok(Self::Cuda(storage)) } } } + // This assumes a contiguous layout and no offset. pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { match self { Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?, @@ -93,32 +88,28 @@ impl Storage { Ok(()) } - pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> { + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> { match self { Storage::Cpu(storage) => { - let storage = storage.to_dtype(shape, stride, dtype)?; + let storage = storage.to_dtype(layout, dtype)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.to_dtype(shape, stride, dtype)?; + let storage = storage.to_dtype(layout, dtype)?; Ok(Self::Cuda(storage)) } } } - pub(crate) fn unary_impl<B: op::UnaryOp>( - &self, - shape: &Shape, - stride: &[usize], - ) -> Result<Self> { + pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> { // TODO: Different code path for the contiguous case? match self { Storage::Cpu(storage) => { - let storage = storage.unary_impl::<B>(shape, stride)?; + let storage = storage.unary_impl::<B>(layout)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.unary_impl::<B>(shape, stride)?; + let storage = storage.unary_impl::<B>(layout)?; Ok(Self::Cuda(storage)) } } @@ -127,19 +118,18 @@ impl Storage { pub(crate) fn binary_impl<B: op::BinaryOp>( &self, rhs: &Self, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result<Self> { self.same_device(rhs, B::NAME)?; self.same_dtype(rhs, B::NAME)?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?; + let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?; + let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => { @@ -156,49 +146,41 @@ impl Storage { pub(crate) fn where_cond( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &Self, - stride_t: &[usize], + layout_t: &Layout, f: &Self, - stride_f: &[usize], + layout_f: &Layout, ) -> Result<Self> { self.same_device(t, "where")?; self.same_device(f, "where")?; t.same_dtype(f, "where")?; match (self, t, f) { (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => { - let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?; + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; Ok(Self::Cpu(storage)) } (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => { - let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?; + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; Ok(Self::Cuda(storage)) } (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), - op: "embedding", + op: "where", }), } } - pub(crate) fn embedding_impl( - &self, - shape: &Shape, - stride: &[usize], - rhs: &Self, - hidden_size: usize, - vocab_size: usize, - ) -> Result<Self> { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { self.same_device(rhs, "embedding")?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding(layout, rhs, rhs_l)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding(layout, rhs, rhs_l)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { @@ -209,22 +191,22 @@ impl Storage { } } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, bmnk: (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result<Self> { self.same_device(rhs, "matmul")?; self.same_dtype(rhs, "matmul")?; match (self, rhs) { (Self::Cpu(lhs), Self::Cpu(rhs)) => { - let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { @@ -240,17 +222,11 @@ impl Storage { &self, dst: &mut Self, dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: &Layout, ) -> Result<()> { match (self, dst) { - (Self::Cpu(src), Self::Cpu(dst)) => { - src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::Cuda(src), Self::Cuda(dst)) => { - Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?) - } + (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l), + (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?), (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), |