diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 2 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 2 | ||||
-rw-r--r-- | candle-core/src/shape.rs | 18 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 12 |
4 files changed, 20 insertions, 14 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index b8d52c95..82e1f3e2 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1688,7 +1688,7 @@ impl BackendStorage for CpuStorage { fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> { let ids = self.as_slice::<u32>()?; - let (vocab_size, hidden_size) = rhs_l.shape().r2()?; + let (vocab_size, hidden_size) = rhs_l.shape().dims2()?; Embedding { vocab_size, hidden_size, diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 43bfef2d..f9fefe17 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -620,7 +620,7 @@ impl<'a> Map1 for Embedding<'a> { let shape = ids_l.shape(); let (v_size, h_size) = rhs_l .shape() - .r2() + .dims2() .map_err(|e| CudaError::WrappedError(Box::new(e))) .w()?; let dims = shape.dims(); diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 982f9db0..b016ead5 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -87,6 +87,12 @@ macro_rules! extract_dims { } } } + impl crate::Tensor { + pub fn $fn_name(&self) -> Result<$out_type> { + self.shape().$fn_name() + } + } + impl std::convert::TryInto<$out_type> for Shape { type Error = crate::Error; fn try_into(self) -> std::result::Result<$out_type, Self::Error> { @@ -328,23 +334,23 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) { } } -extract_dims!(r0, 0, |_: &Vec<usize>| (), ()); -extract_dims!(r1, 1, |d: &[usize]| d[0], usize); -extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); +extract_dims!(dims0, 0, |_: &Vec<usize>| (), ()); +extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); +extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); extract_dims!( - r3, + dims3, 3, |d: &[usize]| (d[0], d[1], d[2]), (usize, usize, usize) ); extract_dims!( - r4, + dims4, 4, |d: &[usize]| (d[0], d[1], d[2], d[3]), (usize, usize, usize, usize) ); extract_dims!( - r5, + dims5, 5, |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]), (usize, usize, usize, usize, usize) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 8ba0ba43..561f1863 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -772,7 +772,7 @@ impl Tensor { /// Applies a 1D convolution over the input tensor. pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { - let (c_out, c_in_k, k_size) = kernel.shape().r3()?; + let (c_out, c_in_k, k_size) = kernel.dims3()?; let (b_size, c_in, l_in) = match *self.dims() { [b_size, c_in, l_in] => (Some(b_size), c_in, l_in), [c_in, l_in] => (None, c_in, l_in), @@ -931,8 +931,8 @@ impl Tensor { .bt())? } let ids_shape = ids.shape(); - let seq_len = ids_shape.r1()?; - let (_, hidden_size) = rhs.shape().r2()?; + let seq_len = ids_shape.dims1()?; + let (_, hidden_size) = rhs.dims2()?; let storage = ids .storage() .embedding(ids.layout(), &rhs.storage(), rhs.layout())?; @@ -1013,7 +1013,7 @@ impl Tensor { // The number of element in indexes must match the dimension on which the add is // performed on the source tensor (and the index values from `indexes` are taken from // the target tensor self) - mismatch || source_dims[dim] != indexes.shape().r1()? + mismatch || source_dims[dim] != indexes.dims1()? }; if mismatch { Err(Error::ShapeMismatchBinaryOp { @@ -1144,7 +1144,7 @@ impl Tensor { /// Returns the data contained in a 2D tensor as a vector of vector of scalar values. pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> { - let (dim1, dim2) = self.shape().r2()?; + let (dim1, dim2) = self.dims2()?; let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; let mut rows = vec![]; @@ -1164,7 +1164,7 @@ impl Tensor { /// Returns the data contained in a 3D tensor. pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> { - let (dim1, dim2, dim3) = self.shape().r3()?; + let (dim1, dim2, dim3) = self.dims3()?; let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; let mut top_rows = vec![]; |