 candle-core/src/shape.rs  |  27 ++++++
 candle-core/src/tensor.rs | 100 ++++++++++-
 2 files changed, 124 insertions(+), 3 deletions(-)
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 1152dc3e..632ef116 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -185,6 +185,7 @@ impl Shape {
pub trait Dim {
fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
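+ /// Similar to `to_index`, but also accepts an index equal to the rank of
+ /// `shape`, i.e. one position past the last dimension. Ops that insert a
+ /// new dimension, such as `stack`, use this to target that extra position.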
+ fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
}
impl Dim for usize {
@@ -200,6 +201,19 @@ impl Dim for usize {
Ok(dim)
}
}
+
+ fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+ let dim = *self;
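+ // Note the `>` rather than `>=`: `dim == shape.dims().len()` is accepted
+ // as the position one past the last dimension.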
+ if dim > shape.dims().len() {
+ Err(Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim,
+ op,
+ })?
+ } else {
+ Ok(dim)
+ }
+ }
}
pub enum D {
@@ -220,6 +234,19 @@ impl Dim for D {
}),
}
}
+
+ fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+ let rank = shape.rank();
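+ // With one extra trailing position available, `Minus1` maps to `rank`
+ // itself and `Minus2` to the current last dimension, `rank - 1`.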
+ match self {
+ Self::Minus1 if rank >= 1 => Ok(rank),
+ Self::Minus2 if rank >= 2 => Ok(rank - 1),
+ _ => Err(Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim: 42, // TODO: Have an adequate error
+ op,
+ }),
+ }
+ }
}
#[cfg(test)]
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f0ce18f9..9b0681e0 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -33,6 +33,19 @@ impl AsRef<Tensor> for Tensor {
// Storages are also refcounted independently so that it's possible to avoid
// copying the storage for operations that only modify the shape or stride.
#[derive(Clone)]
+/// The core struct for manipulating tensors.
+///
+/// ```rust
+/// use candle::{Tensor, DType, Device};
+///
+/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
+///
+/// let c = a.matmul(&b)?;
+/// # Ok::<(), candle::Error>(())
+/// ```
+///
+/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
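+///
+/// A small sketch of what "cheap" means here: cloning only bumps a reference
+/// count on the shared storage, no tensor data is copied.
+///
+/// ```rust
+/// use candle::{Tensor, DType, Device};
+///
+/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+/// let b = a.clone(); // `b` shares the same storage as `a`
+/// # Ok::<(), candle::Error>(())
+/// ```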
pub struct Tensor(Arc<Tensor_>);
impl std::ops::Deref for Tensor {
@@ -126,6 +139,15 @@ impl Tensor {
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
}
+ /// Creates a new tensor filled with ones.
+ ///
+ /// ```rust
+ /// use candle::{Tensor, DType, Device};
+ /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
+ /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
+ /// // a == b
+ /// # Ok::<(), candle::Error>(())
+ /// ```
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::ones_impl(shape, dtype, device, false)
}
@@ -136,10 +158,29 @@ impl Tensor {
Self::ones_impl(shape, dtype, device, true)
}
+ /// Creates a new tensor filled with ones, with the same shape, dtype, and
+ /// device as `self`.
+ ///
+ /// ```rust
+ /// use candle::{Tensor, DType, Device};
+ /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ /// let b = a.ones_like()?;
+ /// // b == a + 1
+ /// # Ok::<(), candle::Error>(())
+ /// ```
pub fn ones_like(&self) -> Result<Self> {
Tensor::ones(self.shape(), self.dtype(), &self.device())
}
+ /// Creates a new tensor filled with zeros.
+ ///
+ /// ```rust
+ /// use candle::{Tensor, DType, Device};
+ /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
+ /// // a == b
+ /// # Ok::<(), candle::Error>(())
+ /// ```
fn zeros_impl<S: Into<Shape>>(
shape: S,
dtype: DType,
@@ -150,6 +191,15 @@ impl Tensor {
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
}
+ /// Creates a new tensor filled with zeros.
+ ///
+ /// ```rust
+ /// use candle::{Tensor, DType, Device};
+ /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
+ /// // a == b
+ /// # Ok::<(), candle::Error>(())
+ /// ```
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::zeros_impl(shape, dtype, device, false)
}
@@ -158,6 +208,16 @@ impl Tensor {
Self::zeros_impl(shape, dtype, device, true)
}
+ /// Creates a new tensor filled with zeros, with the same shape, dtype, and
+ /// device as `self`.
+ ///
+ /// ```rust
+ /// use candle::{Tensor, DType, Device};
+ /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ /// let b = a.zeros_like()?;
+ /// // b is a (2, 3) tensor of zeros on the CPU, with dtype f32.
+ /// # Ok::<(), candle::Error>(())
+ /// ```
pub fn zeros_like(&self) -> Result<Self> {
Tensor::zeros(self.shape(), self.dtype(), &self.device())
}
@@ -187,7 +247,7 @@ impl Tensor {
Self::new_impl(array, shape, device, true)
}
- pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+ fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
@@ -986,11 +1046,28 @@ impl Tensor {
self.reshape(dims)
}
+ /// Stacks two or more tensors along a particular dimension.
+ ///
+ /// All tensors must have the same rank, and the output has one
+ /// more dimension than the inputs.
+ ///
+ /// ```rust
+ /// # use candle::{Tensor, DType, Device};
+ /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ ///
+ /// let c = Tensor::stack(&[&a, &b], 0)?;
+ /// assert_eq!(c.shape().dims(), &[2, 2, 3]);
+ ///
+ /// let c = Tensor::stack(&[&a, &b], 2)?;
+ /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
+ /// # Ok::<(), candle::Error>(())
+ /// ```
pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
}
- let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
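+ // The `_plus_one` variant allows the new dimension to be inserted one
+ // past the current last dimension, e.g. stacking two (2, 3) tensors
+ // along dim 2 yields a (2, 3, 2) tensor.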
+ let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
let args = args
.iter()
.map(|t| t.as_ref().unsqueeze(dim))
@@ -998,6 +1075,23 @@ impl Tensor {
Self::cat(&args, dim)
}
+ /// Concatenates two or more tensors along a particular dimension.
+ ///
+ /// All tensors must have the same rank, and the output has the
+ /// same rank as the inputs.
+ ///
+ /// ```rust
+ /// # use candle::{Tensor, DType, Device};
+ /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ ///
+ /// let c = Tensor::cat(&[&a, &b], 0)?;
+ /// assert_eq!(c.shape().dims(), &[4, 3]);
+ ///
+ /// let c = Tensor::cat(&[&a, &b], 1)?;
+ /// assert_eq!(c.shape().dims(), &[2, 6]);
+ /// # Ok::<(), candle::Error>(())
+ /// ```
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
@@ -1024,7 +1118,7 @@ impl Tensor {
}
}
- pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
+ fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
}