diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-10 15:21:24 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-10 15:21:24 +0200 |
commit | dc5825967957e28e6ac4f57da18c7963f2be542c (patch) | |
tree | f8249c4d0259c1c8f0c1e46c7f1ecd95da258580 /candle-core/src/tensor.rs | |
parent | 204618b7d37229cd19a7f85ed38e6ab916e1e0d1 (diff) | |
parent | 9a667155fd554fe270561783f6708445e2deb929 (diff) | |
download | candle-dc5825967957e28e6ac4f57da18c7963f2be542c.tar.gz candle-dc5825967957e28e6ac4f57da18c7963f2be542c.tar.bz2 candle-dc5825967957e28e6ac4f57da18c7963f2be542c.zip |
Merge pull request #120 from LaurentMazare/some_doc_plus_fix_stack
Adding some doc + Extended `stack` to work with extra final dimensions.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 100 |
1 files changed, 97 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f0ce18f9..9b0681e0 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -33,6 +33,19 @@ impl AsRef<Tensor> for Tensor { // Storages are also refcounted independently so that its possible to avoid // copying the storage for operations that only modify the shape or stride. #[derive(Clone)] +/// The core struct for manipulating tensors. +/// +/// ```rust +/// use candle::{Tensor, DType, Device}; +/// +/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; +/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?; +/// +/// let c = a.matmul(&b)?; +/// # Ok::<(), candle::Error>(()) +/// ``` +/// +/// Tensors are reference counted with [`Arc`] so cloning them is cheap. pub struct Tensor(Arc<Tensor_>); impl std::ops::Deref for Tensor { @@ -126,6 +139,15 @@ impl Tensor { from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape) } + /// Create a new tensors filled with ones + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?; + /// // a == b + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> { Self::ones_impl(shape, dtype, device, false) } @@ -136,10 +158,29 @@ impl Tensor { Self::ones_impl(shape, dtype, device, true) } + /// Create a new tensors filled with ones with same shape, dtype, and device + /// as the other tensors + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = a.ones_like()?; + /// // b == a + 1 + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn ones_like(&self) -> Result<Self> { Tensor::ones(self.shape(), self.dtype(), &self.device()) } + /// Create a new tensors filled with zeros + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?; + /// // a == b + /// # Ok::<(), candle::Error>(()) + /// ``` fn zeros_impl<S: Into<Shape>>( shape: S, dtype: DType, @@ -150,6 +191,15 @@ impl Tensor { from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape) } + /// Create a new tensors filled with zeros + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?; + /// // a == b + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> { Self::zeros_impl(shape, dtype, device, false) } @@ -158,6 +208,16 @@ impl Tensor { Self::zeros_impl(shape, dtype, device, true) } + /// Create a new tensors filled with ones with same shape, dtype, and device + /// as the other tensors + /// + /// ```rust + /// use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = a.zeros_like()?; + /// // b is on CPU f32. + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn zeros_like(&self) -> Result<Self> { Tensor::zeros(self.shape(), self.dtype(), &self.device()) } @@ -187,7 +247,7 @@ impl Tensor { Self::new_impl(array, shape, device, true) } - pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>( + fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>( data: Vec<D>, shape: S, device: &Device, @@ -986,11 +1046,28 @@ impl Tensor { self.reshape(dims) } + /// Stacks two or more tensors along a particular dimension. + /// + /// All tensors must have the same rank, and the output has + /// 1 additional rank + /// + /// ```rust + /// # use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = Tensor::stack(&[&a, &b], 0)?; + /// assert_eq!(c.shape().dims(), &[2, 2, 3]); + /// + /// let c = Tensor::stack(&[&a, &b], 2)?; + /// assert_eq!(c.shape().dims(), &[2, 3, 2]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }); } - let dim = dim.to_index(args[0].as_ref().shape(), "stack")?; + let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?; let args = args .iter() .map(|t| t.as_ref().unsqueeze(dim)) @@ -998,6 +1075,23 @@ impl Tensor { Self::cat(&args, dim) } + /// Concatenates two or more tensors along a particular dimension. + /// + /// All tensors must of the same rank, and the output will have + /// the same rank + /// + /// ```rust + /// # use candle::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = Tensor::cat(&[&a, &b], 0)?; + /// assert_eq!(c.shape().dims(), &[4, 3]); + /// + /// let c = Tensor::cat(&[&a, &b], 1)?; + /// assert_eq!(c.shape().dims(), &[2, 6]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); @@ -1024,7 +1118,7 @@ impl Tensor { } } - pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> { + fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> { if args.is_empty() { return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }); } |