diff options
-rw-r--r-- | candle-core/src/shape.rs | 27 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 100 |
2 files changed, 124 insertions, 3 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 1152dc3e..632ef116 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -185,6 +185,7 @@ impl Shape { pub trait Dim { fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>; + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>; } impl Dim for usize { @@ -200,6 +201,19 @@ impl Dim for usize { Ok(dim) } } + + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> { + let dim = *self; + if dim > shape.dims().len() { + Err(Error::DimOutOfRange { + shape: shape.clone(), + dim, + op, + })? + } else { + Ok(dim) + } + } } pub enum D { @@ -220,6 +234,19 @@ impl Dim for D { }), } } + + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> { + let rank = shape.rank(); + match self { + Self::Minus1 if rank >= 1 => Ok(rank), + Self::Minus2 if rank >= 2 => Ok(rank - 1), + _ => Err(Error::DimOutOfRange { + shape: shape.clone(), + dim: 42, // TODO: Have an adequate error + op, + }), + } + } } #[cfg(test)] diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f0ce18f9..9b0681e0 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -33,6 +33,19 @@ impl AsRef<Tensor> for Tensor { // Storages are also refcounted independently so that its possible to avoid // copying the storage for operations that only modify the shape or stride. #[derive(Clone)] +/// The core struct for manipulating tensors. +/// +/// ```rust +/// use candle::{Tensor, DType, Device}; +/// +/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; +/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?; +/// +/// let c = a.matmul(&b)?; +/// # Ok::<(), candle::Error>(()) +/// ``` +/// +/// Tensors are reference counted with [`Arc`] so cloning them is cheap. pub struct Tensor(Arc<Tensor_>); impl std::ops::Deref for Tensor { @@ -126,6 +139,15 @@ impl Tensor { from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape) } + /// Create a new tensors filled with ones + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?; + /// // a == b + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> { Self::ones_impl(shape, dtype, device, false) } @@ -136,10 +158,29 @@ impl Tensor { Self::ones_impl(shape, dtype, device, true) } + /// Create a new tensors filled with ones with same shape, dtype, and device + /// as the other tensors + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = a.ones_like()?; + /// // b == a + 1 + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn ones_like(&self) -> Result<Self> { Tensor::ones(self.shape(), self.dtype(), &self.device()) } + /// Create a new tensors filled with zeros + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?; + /// // a == b + /// # Ok::<(), candle::Error>(()) + /// ``` fn zeros_impl<S: Into<Shape>>( shape: S, dtype: DType, @@ -150,6 +191,15 @@ impl Tensor { from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape) } + /// Create a new tensors filled with zeros + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?; + /// // a == b + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> { Self::zeros_impl(shape, dtype, device, false) } @@ -158,6 +208,16 @@ impl Tensor { Self::zeros_impl(shape, dtype, device, true) } + /// Create a new tensors filled with ones with same shape, dtype, and device + /// as the other tensors + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = a.zeros_like()?; + /// // b is on CPU f32. + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn zeros_like(&self) -> Result<Self> { Tensor::zeros(self.shape(), self.dtype(), &self.device()) } @@ -187,7 +247,7 @@ impl Tensor { Self::new_impl(array, shape, device, true) } - pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>( + fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>( data: Vec<D>, shape: S, device: &Device, @@ -986,11 +1046,28 @@ impl Tensor { self.reshape(dims) } + /// Stacks two or more tensors along a particular dimension. + /// + /// All tensors must have the same rank, and the output has + /// 1 additional rank + /// + /// ```rust + /// # use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = Tensor::stack(&[&a, &b], 0)?; + /// assert_eq!(c.shape().dims(), &[2, 2, 3]); + /// + /// let c = Tensor::stack(&[&a, &b], 2)?; + /// assert_eq!(c.shape().dims(), &[2, 3, 2]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }); } - let dim = dim.to_index(args[0].as_ref().shape(), "stack")?; + let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?; let args = args .iter() .map(|t| t.as_ref().unsqueeze(dim)) @@ -998,6 +1075,23 @@ impl Tensor { Self::cat(&args, dim) } + /// Concatenates two or more tensors along a particular dimension. + /// + /// All tensors must of the same rank, and the output will have + /// the same rank + /// + /// ```rust + /// # use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = Tensor::cat(&[&a, &b], 0)?; + /// assert_eq!(c.shape().dims(), &[4, 3]); + /// + /// let c = Tensor::cat(&[&a, &b], 1)?; + /// assert_eq!(c.shape().dims(), &[2, 6]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); @@ -1024,7 +1118,7 @@ impl Tensor { } } - pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> { + fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); } |