Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/lib.rs    |  2
-rw-r--r--  candle-core/src/shape.rs  | 39
-rw-r--r--  candle-core/src/tensor.rs | 59
3 files changed, 80 insertions, 20 deletions
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 2365a34d..9a2602f4 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -23,7 +23,7 @@ pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use layout::Layout;
-pub use shape::Shape;
+pub use shape::{Shape, D};
pub use storage::Storage;
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index cc068004..1152dc3e 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -183,6 +183,45 @@ impl Shape {
}
}
+pub trait Dim {
+ fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
+}
+
+impl Dim for usize {
+ fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+ let dim = *self;
+ if dim >= shape.dims().len() {
+ Err(Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim,
+ op,
+ })?
+ } else {
+ Ok(dim)
+ }
+ }
+}
+
+pub enum D {
+ Minus1,
+ Minus2,
+}
+
+impl Dim for D {
+ fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+ let rank = shape.rank();
+ match self {
+ Self::Minus1 if rank >= 1 => Ok(rank - 1),
+ Self::Minus2 if rank >= 2 => Ok(rank - 2),
+ _ => Err(Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim: 42, // TODO: Have an adequate error
+ op,
+ }),
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
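
The new `Dim` trait is what lets tensor methods accept either an absolute index (`usize`) or a relative one (`D::Minus1`, `D::Minus2`). A minimal sketch of the dispatch it enables, using a hypothetical `last_dim_size` helper that mirrors what the updated `Tensor` methods below do (the `candle_core::shape::Dim` import assumes the `shape` module is reachable from downstream code):

    use candle_core::shape::Dim; // assumption: the `shape` module is public
    use candle_core::{Result, Shape, D};

    // Hypothetical helper, not part of the crate: resolve a dimension the same
    // way the updated `Tensor` methods do, then read its size.
    fn last_dim_size<T: Dim>(shape: &Shape, dim: T) -> Result<usize> {
        let dim = dim.to_index(shape, "last_dim_size")?;
        Ok(shape.dims()[dim])
    }

    // Both call styles compile against the same helper:
    // `last_dim_size(&shape, 2)` and `last_dim_size(&shape, D::Minus1)`
    // resolve to the same index on a rank-3 shape.
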
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 95f663f0..1eb92e6a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+use crate::shape::Dim;
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::Arc;
@@ -362,9 +363,9 @@ impl Tensor {
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + len`.
- pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
+ pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
- self.check_dim(dim, "narrow")?;
+ let dim = dim.to_index(self.shape(), "narrow")?;
if start + len > dims[dim] {
Err(Error::NarrowInvalidArgs {
shape: self.shape().clone(),
@@ -392,8 +393,8 @@ impl Tensor {
}
}
- pub fn softmax(&self, dim: usize) -> Result<Self> {
- self.check_dim(dim, "softmax")?;
+ pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "softmax")?;
// TODO: unify the two branches.
if self.device().is_cuda() {
// We do not have a cuda kernel for divide_by_sum_over_dim so split
@@ -692,14 +693,22 @@ impl Tensor {
self.sum(&dims)
}
- pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
+ fn flatten_<D1: Dim, D2: Dim>(
+ &self,
+ start_dim: Option<D1>,
+ end_dim: Option<D2>,
+ ) -> Result<Tensor> {
if self.rank() == 0 {
self.reshape(1)
} else {
- let start_dim = start_dim.unwrap_or(0);
- let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
- self.check_dim(start_dim, "flatten")?;
- self.check_dim(end_dim, "flatten")?;
+ let start_dim = match start_dim {
+ None => 0,
+ Some(dim) => dim.to_index(self.shape(), "flatten")?,
+ };
+ let end_dim = match end_dim {
+ None => self.rank() - 1,
+ Some(dim) => dim.to_index(self.shape(), "flatten")?,
+ };
if start_dim < end_dim {
let dims = self.dims();
let mut dst_dims = dims[..start_dim].to_vec();
@@ -714,8 +723,20 @@ impl Tensor {
}
}
+ pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
+ self.flatten_(Some(start_dim), Some(end_dim))
+ }
+
+ pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
+ self.flatten_(None::<usize>, Some(end_dim))
+ }
+
+ pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
+ self.flatten_(Some(start_dim), None::<usize>)
+ }
+
pub fn flatten_all(&self) -> Result<Tensor> {
- self.flatten(None, None)
+ self.flatten_(None::<usize>, None::<usize>)
}
pub fn get(&self, i: usize) -> Result<Tensor> {
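
The three public wrappers give `flatten` its PyTorch-style call patterns without `Option` arguments at the call site. A hedged sketch of what each resolves to, assuming a hypothetical rank-3 tensor `t` of shape (2, 3, 4):

    // `t` is a hypothetical (2, 3, 4) tensor; shapes follow the logic of `flatten_`.
    // t.flatten(1, 2)?           -> (2, 12)  merges dims 1..=2
    // t.flatten(1, D::Minus1)?   -> (2, 12)  same, with a relative end dim
    // t.flatten_to(1)?           -> (6, 4)   merges dims 0..=1
    // t.flatten_from(D::Minus2)? -> (2, 12)  merges the trailing two dims
    // t.flatten_all()?           -> (24,)    merges everything
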
@@ -743,9 +764,9 @@ impl Tensor {
/// Returns a tensor that is a transposed version of the input, the given dimensions are
/// swapped.
- pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
- self.check_dim(dim1, "transpose")?;
- self.check_dim(dim2, "transpose")?;
+ pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
+ let dim1 = dim1.to_index(self.shape(), "transpose")?;
+ let dim2 = dim2.to_index(self.shape(), "transpose")?;
let op = if self.track_op() {
Some(Op::Transpose(self.clone(), dim1, dim2))
} else {
@@ -929,23 +950,23 @@ impl Tensor {
}
}
- pub fn squeeze(&self, index: usize) -> Result<Self> {
+ pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
// The PyTorch semantics are to return the same tensor if the target dimension
// does not have a size of 1.
let dims = self.dims();
- self.check_dim(index, "squeeze")?;
- if dims[index] == 1 {
+ let dim = dim.to_index(self.shape(), "squeeze")?;
+ if dims[dim] == 1 {
let mut dims = dims.to_vec();
- dims.remove(index);
+ dims.remove(dim);
self.reshape(dims)
} else {
Ok(self.clone())
}
}
- pub fn unsqueeze(&self, index: usize) -> Result<Self> {
+ pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
let mut dims = self.dims().to_vec();
- dims.insert(index, 1);
+ dims.insert(dim, 1);
self.reshape(dims)
}
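
Putting the pieces together, a minimal end-to-end sketch of the new call-site ergonomics, assuming `Tensor::zeros(shape, dtype, device)` is available in this revision:

    use candle_core::{DType, Device, Result, Tensor, D};

    fn main() -> Result<()> {
        // Assumption: `Tensor::zeros(shape, dtype, device)` exists at this revision.
        let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;

        // Absolute and relative dimension indices mix freely.
        let transposed = t.transpose(0, D::Minus1)?; // swap the first and last dims
        assert_eq!(transposed.dims(), &[4, 3, 2]);

        let narrowed = t.narrow(D::Minus1, 0, 2)?; // keep 2 elements of the last dim
        assert_eq!(narrowed.dims(), &[2, 3, 2]);

        let _soft = t.softmax(D::Minus1)?; // softmax over the last dim

        // `squeeze` now accepts any `Dim`; `unsqueeze` still takes a plain `usize`.
        let roundtrip = t.unsqueeze(0)?.squeeze(0)?;
        assert_eq!(roundtrip.dims(), t.dims());
        Ok(())
    }
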