summaryrefslogtreecommitdiff
path: root/candle-core/src/storage.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/storage.rs')
-rw-r--r--candle-core/src/storage.rs94
1 files changed, 35 insertions, 59 deletions
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index e44a2db6..7acf6dd0 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,4 +1,4 @@
-use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
+use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@@ -53,38 +53,33 @@ impl Storage {
}
}
- pub(crate) fn affine_impl(
- &self,
- shape: &Shape,
- stride: &[usize],
- mul: f64,
- add: f64,
- ) -> Result<Self> {
+ pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
- let storage = storage.affine_impl(shape, stride, mul, add)?;
+ let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
- let storage = storage.affine_impl(shape, stride, mul, add)?;
+ let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cuda(storage))
}
}
}
- pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
+ pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
- let storage = storage.sum(shape, stride, s)?;
+ let storage = storage.sum(layout, s)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
- let storage = storage.sum(shape, stride, s)?;
+ let storage = storage.sum(layout, s)?;
Ok(Self::Cuda(storage))
}
}
}
+ // This assumes a contiguous layout and no offset.
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
@@ -93,32 +88,28 @@ impl Storage {
Ok(())
}
- pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+ pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
- let storage = storage.to_dtype(shape, stride, dtype)?;
+ let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
- let storage = storage.to_dtype(shape, stride, dtype)?;
+ let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cuda(storage))
}
}
}
- pub(crate) fn unary_impl<B: op::UnaryOp>(
- &self,
- shape: &Shape,
- stride: &[usize],
- ) -> Result<Self> {
+ pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
- let storage = storage.unary_impl::<B>(shape, stride)?;
+ let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
- let storage = storage.unary_impl::<B>(shape, stride)?;
+ let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cuda(storage))
}
}
@@ -127,19 +118,18 @@ impl Storage {
pub(crate) fn binary_impl<B: op::BinaryOp>(
&self,
rhs: &Self,
- shape: &Shape,
- lhs_stride: &[usize],
- rhs_stride: &[usize],
+ lhs_layout: &Layout,
+ rhs_layout: &Layout,
) -> Result<Self> {
self.same_device(rhs, B::NAME)?;
self.same_dtype(rhs, B::NAME)?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
- let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
+ let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
- let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
+ let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => {
@@ -156,49 +146,41 @@ impl Storage {
pub(crate) fn where_cond(
&self,
- shape: &Shape,
- stride: &[usize],
+ layout: &Layout,
t: &Self,
- stride_t: &[usize],
+ layout_t: &Layout,
f: &Self,
- stride_f: &[usize],
+ layout_f: &Layout,
) -> Result<Self> {
self.same_device(t, "where")?;
self.same_device(f, "where")?;
t.same_dtype(f, "where")?;
match (self, t, f) {
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
- let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+ let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
- let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+ let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cuda(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
- op: "embedding",
+ op: "where",
}),
}
}
- pub(crate) fn embedding_impl(
- &self,
- shape: &Shape,
- stride: &[usize],
- rhs: &Self,
- hidden_size: usize,
- vocab_size: usize,
- ) -> Result<Self> {
+ pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
self.same_device(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
- let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
+ let storage = lhs.embedding(layout, rhs, rhs_l)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
- let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
+ let storage = lhs.embedding(layout, rhs, rhs_l)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
@@ -209,22 +191,22 @@ impl Storage {
}
}
- pub(crate) fn matmul_impl(
+ pub(crate) fn matmul(
&self,
rhs: &Self,
bmnk: (usize, usize, usize, usize),
- lhs_stride: &[usize],
- rhs_stride: &[usize],
+ lhs_layout: &Layout,
+ rhs_layout: &Layout,
) -> Result<Self> {
self.same_device(rhs, "matmul")?;
self.same_dtype(rhs, "matmul")?;
match (self, rhs) {
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
- let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
+ let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
- let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
+ let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
@@ -240,17 +222,11 @@ impl Storage {
&self,
dst: &mut Self,
dst_offset: usize,
- src_shape: &Shape,
- src_stride: &[usize],
- src_offset: usize,
+ src_l: &Layout,
) -> Result<()> {
match (self, dst) {
- (Self::Cpu(src), Self::Cpu(dst)) => {
- src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)
- }
- (Self::Cuda(src), Self::Cuda(dst)) => {
- Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?)
- }
+ (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
+ (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),