diff options
author | laurent <laurent.mazare@gmail.com> | 2023-07-05 20:22:43 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-07-05 20:22:43 +0100 |
commit | 2c3d871b2e7490cb3740674647a03b0dcc8f67b6 (patch) | |
tree | 7c8d867001fbecec127ad9581056c0fd2f67f2a3 /candle-core/src/tensor.rs | |
parent | b7388bbf718f9301b7e41e222654217f18e4c1e1 (diff) | |
download | candle-2c3d871b2e7490cb3740674647a03b0dcc8f67b6.tar.gz candle-2c3d871b2e7490cb3740674647a03b0dcc8f67b6.tar.bz2 candle-2c3d871b2e7490cb3740674647a03b0dcc8f67b6.zip |
Add a simpler way to specify the dim index for some ops.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 59 |
1 files changed, 40 insertions, 19 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 95f663f0..1eb92e6a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,3 +1,4 @@ +use crate::shape::Dim; use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::Arc; @@ -362,9 +363,9 @@ impl Tensor { /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// ranges from `start` to `start + len`. - pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> { + pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> { let dims = self.dims(); - self.check_dim(dim, "narrow")?; + let dim = dim.to_index(self.shape(), "narrow")?; if start + len > dims[dim] { Err(Error::NarrowInvalidArgs { shape: self.shape().clone(), @@ -392,8 +393,8 @@ impl Tensor { } } - pub fn softmax(&self, dim: usize) -> Result<Self> { - self.check_dim(dim, "softmax")?; + pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> { + let dim = dim.to_index(self.shape(), "softmax")?; // TODO: unify the two branches. if self.device().is_cuda() { // We do not have a cuda kernel for divide_by_sum_over_dim so split @@ -692,14 +693,22 @@ impl Tensor { self.sum(&dims) } - pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> { + fn flatten_<D1: Dim, D2: Dim>( + &self, + start_dim: Option<D1>, + end_dim: Option<D2>, + ) -> Result<Tensor> { if self.rank() == 0 { self.reshape(1) } else { - let start_dim = start_dim.unwrap_or(0); - let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1); - self.check_dim(start_dim, "flatten")?; - self.check_dim(end_dim, "flatten")?; + let start_dim = match start_dim { + None => 0, + Some(dim) => dim.to_index(self.shape(), "flatten")?, + }; + let end_dim = match end_dim { + None => self.rank() - 1, + Some(dim) => dim.to_index(self.shape(), "flatten")?, + }; if start_dim < end_dim { let dims = self.dims(); let mut dst_dims = dims[..start_dim].to_vec(); @@ -714,8 +723,20 @@ impl Tensor { } } + pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> { + self.flatten_(Some(start_dim), Some(end_dim)) + } + + pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> { + self.flatten_(None::<usize>, Some(end_dim)) + } + + pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> { + self.flatten_(Some(start_dim), None::<usize>) + } + pub fn flatten_all(&self) -> Result<Tensor> { - self.flatten(None, None) + self.flatten_(None::<usize>, None::<usize>) } pub fn get(&self, i: usize) -> Result<Tensor> { @@ -743,9 +764,9 @@ impl Tensor { /// Returns a tensor that is a transposed version of the input, the given dimensions are /// swapped. - pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> { - self.check_dim(dim1, "transpose")?; - self.check_dim(dim2, "transpose")?; + pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> { + let dim1 = dim1.to_index(self.shape(), "transpose")?; + let dim2 = dim2.to_index(self.shape(), "transpose")?; let op = if self.track_op() { Some(Op::Transpose(self.clone(), dim1, dim2)) } else { @@ -929,23 +950,23 @@ impl Tensor { } } - pub fn squeeze(&self, index: usize) -> Result<Self> { + pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> { // The PyTorch semantics are to return the same tensor if the target dimension // does not have a size of 1. let dims = self.dims(); - self.check_dim(index, "squeeze")?; - if dims[index] == 1 { + let dim = dim.to_index(self.shape(), "squeeze")?; + if dims[dim] == 1 { let mut dims = dims.to_vec(); - dims.remove(index); + dims.remove(dim); self.reshape(dims) } else { Ok(self.clone()) } } - pub fn unsqueeze(&self, index: usize) -> Result<Self> { + pub fn unsqueeze(&self, dim: usize) -> Result<Self> { let mut dims = self.dims().to_vec(); - dims.insert(index, 1); + dims.insert(dim, 1); self.reshape(dims) } |