From 38ac50eeda5a9c77d41f66fb900bdfa67df1fa20 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 10 Jul 2023 14:51:10 +0200
Subject: Adding some doc + Extended `stack` to work with extra final dimensions.

---
 candle-core/src/shape.rs  |  27 ++++++++++
 candle-core/src/tensor.rs | 124 +++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 138 insertions(+), 13 deletions(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 1152dc3e..632ef116 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -185,6 +185,7 @@ impl Shape {
 
 pub trait Dim {
     fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
+    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
 }
 
 impl Dim for usize {
@@ -200,6 +201,19 @@ impl Dim for usize {
             Ok(dim)
         }
     }
+
+    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+        let dim = *self;
+        if dim > shape.dims().len() {
+            Err(Error::DimOutOfRange {
+                shape: shape.clone(),
+                dim,
+                op,
+            })?
+        } else {
+            Ok(dim)
+        }
+    }
 }
 
 pub enum D {
@@ -220,6 +234,19 @@ impl Dim for D {
             }),
         }
     }
+
+    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+        let rank = shape.rank();
+        match self {
+            Self::Minus1 if rank >= 1 => Ok(rank),
+            Self::Minus2 if rank >= 2 => Ok(rank - 1),
+            _ => Err(Error::DimOutOfRange {
+                shape: shape.clone(),
+                dim: 42, // TODO: Have an adequate error
+                op,
+            }),
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f0ce18f9..cf8d01e6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+// #![deny(missing_docs)]
 use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::Arc;
@@ -33,6 +34,19 @@ impl AsRef<Tensor> for Tensor {
 // Storages are also refcounted independently so that its possible to avoid
 // copying the storage for operations that only modify the shape or stride.
 #[derive(Clone)]
+/// The core struct for manipulating tensors.
+///
+/// ```rust
+/// use candle::{Tensor, DType, Device};
+///
+/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
+///
+/// let c = a.matmul(&b)?;
+/// # Ok::<(), candle::Error>(())
+/// ```
+///
+/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
 pub struct Tensor(Arc<Tensor_>);
 
 impl std::ops::Deref for Tensor {
@@ -126,20 +140,51 @@ impl Tensor {
         from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
+    /// Creates a new tensor filled with ones
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
+    /// // a == b
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        // Maybe we should allocate some actual storage for vars rather than just using a
-        // broadcasted scalar?
-        Self::ones_impl(shape, dtype, device, true)
-    }
-
+    // Hiding it for now, having these functions forces us to have *every* function that creates
+    // a new tensor potentially `_var`. Maybe having something more like `Tensor::ones(..).var()`
+    // might be easier to check.
+    // pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+    //     // Maybe we should allocate some actual storage for vars rather than just using a
+    //     // broadcasted scalar?
+    //     Self::ones_impl(shape, dtype, device, true)
+    // }
+
+    /// Creates a new tensor filled with ones with the same shape, dtype, and device
+    /// as the other tensor.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = a.ones_like()?;
+    /// // b == a + 1
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn ones_like(&self) -> Result<Self> {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
+    /// Creates a new tensor filled with zeros
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
+    /// // a == b
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
@@ -150,14 +195,33 @@ impl Tensor {
         from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
+    /// Creates a new tensor filled with zeros
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
+    /// // a == b
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        Self::zeros_impl(shape, dtype, device, true)
-    }
-
+    // pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+    //     Self::zeros_impl(shape, dtype, device, true)
+    // }
+
+    /// Creates a new tensor filled with zeros with the same shape, dtype, and device
+    /// as the other tensor.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = a.zeros_like()?;
+    /// // b is on CPU f32.
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn zeros_like(&self) -> Result<Self> {
         Tensor::zeros(self.shape(), self.dtype(), &self.device())
     }
@@ -187,7 +251,7 @@ impl Tensor {
         Self::new_impl(array, shape, device, true)
     }
 
-    pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+    fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
@@ -986,11 +1050,28 @@ impl Tensor {
         self.reshape(dims)
     }
 
+    /// Stacks two or more tensors along a particular dimension.
+    ///
+    /// All tensors must have the same rank, and the output has
+    /// one additional dimension.
+    ///
+    /// ```rust
+    /// # use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = Tensor::stack(&[&a, &b], 0)?;
+    /// assert_eq!(c.shape().dims(), &[2, 2, 3]);
+    ///
+    /// let c = Tensor::stack(&[&a, &b], 2)?;
+    /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
         }
-        let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
+        let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
         let args = args
             .iter()
             .map(|t| t.as_ref().unsqueeze(dim))
@@ -998,6 +1079,23 @@ impl Tensor {
         Self::cat(&args, dim)
     }
 
+    /// Concatenates two or more tensors along a particular dimension.
+    ///
+    /// All tensors must have the same rank, and the output will have
+    /// the same rank.
+    ///
+    /// ```rust
+    /// # use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = Tensor::cat(&[&a, &b], 0)?;
+    /// assert_eq!(c.shape().dims(), &[4, 3]);
+    ///
+    /// let c = Tensor::cat(&[&a, &b], 1)?;
+    /// assert_eq!(c.shape().dims(), &[2, 6]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
@@ -1024,7 +1122,7 @@ impl Tensor {
         }
     }
 
-    pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
+    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
--
cgit v1.2.3

From 49f4a77ffd715141b50e17abda56a14a1e501886 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 10 Jul 2023 15:11:48 +0200
Subject: Put them back.

---
 candle-core/src/tensor.rs | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index cf8d01e6..4a5d9d16 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -156,11 +156,11 @@ impl Tensor {
     // Hiding it for now, having these functions forces us to have *every* function that creates
     // a new tensor potentially `_var`. Maybe having something more like `Tensor::ones(..).var()`
     // might be easier to check.
-    // pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-    //     // Maybe we should allocate some actual storage for vars rather than just using a
-    //     // broadcasted scalar?
-    //     Self::ones_impl(shape, dtype, device, true)
-    // }
+    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        // Maybe we should allocate some actual storage for vars rather than just using a
+        // broadcasted scalar?
+        Self::ones_impl(shape, dtype, device, true)
+    }
 
     /// Creates a new tensor filled with ones with the same shape, dtype, and device
     /// as the other tensor.
@@ -208,9 +208,9 @@ impl Tensor {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    // pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-    //     Self::zeros_impl(shape, dtype, device, true)
-    // }
+    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        Self::zeros_impl(shape, dtype, device, true)
+    }
 
     /// Creates a new tensor filled with zeros with the same shape, dtype, and device
     /// as the other tensor.
--
cgit v1.2.3

From 2c8fbe8155197ed8e3420267c9610fef0d60d62b Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 10 Jul 2023 15:13:52 +0200
Subject: oops.

---
 candle-core/src/tensor.rs | 3 ---
 1 file changed, 3 deletions(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4a5d9d16..404e3b72 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -153,9 +153,6 @@ impl Tensor {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    // Hiding it for now, having these functions forces us to have *every* function that creates
-    // a new tensor potentially `_var`. Maybe having something more like `Tensor::ones(..).var()`
-    // might be easier to check.
     pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         // Maybe we should allocate some actual storage for vars rather than just using a
         // broadcasted scalar?
--
cgit v1.2.3

From 9a667155fd554fe270561783f6708445e2deb929 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 10 Jul 2023 15:18:23 +0200
Subject: Removed commented deny

---
 candle-core/src/tensor.rs | 1 -
 1 file changed, 1 deletion(-)

(limited to 'candle-core/src')

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 404e3b72..9b0681e0 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,4 +1,3 @@
-// #![deny(missing_docs)]
 use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::Arc;
--
cgit v1.2.3