diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-22 11:39:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-22 10:39:27 +0100 |
commit | 43c72232927ca80c850a73ce977c2063d5a2dcf5 (patch) | |
tree | c93c07984e06b1925313f4f641a8b1a3956fc0ed /candle-core/src/tensor.rs | |
parent | 52c5d8c087f6a2ee91b807467860eb3e96bb6267 (diff) | |
download | candle-43c72232927ca80c850a73ce977c2063d5a2dcf5.tar.gz candle-43c72232927ca80c850a73ce977c2063d5a2dcf5.tar.bz2 candle-43c72232927ca80c850a73ce977c2063d5a2dcf5.zip |
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 8ba0ba43..561f1863 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -772,7 +772,7 @@ impl Tensor { /// Applies a 1D convolution over the input tensor. pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { - let (c_out, c_in_k, k_size) = kernel.shape().r3()?; + let (c_out, c_in_k, k_size) = kernel.dims3()?; let (b_size, c_in, l_in) = match *self.dims() { [b_size, c_in, l_in] => (Some(b_size), c_in, l_in), [c_in, l_in] => (None, c_in, l_in), @@ -931,8 +931,8 @@ impl Tensor { .bt())? } let ids_shape = ids.shape(); - let seq_len = ids_shape.r1()?; - let (_, hidden_size) = rhs.shape().r2()?; + let seq_len = ids_shape.dims1()?; + let (_, hidden_size) = rhs.dims2()?; let storage = ids .storage() .embedding(ids.layout(), &rhs.storage(), rhs.layout())?; @@ -1013,7 +1013,7 @@ impl Tensor { // The number of element in indexes must match the dimension on which the add is // performed on the source tensor (and the index values from `indexes` are taken from // the target tensor self) - mismatch || source_dims[dim] != indexes.shape().r1()? + mismatch || source_dims[dim] != indexes.dims1()? }; if mismatch { Err(Error::ShapeMismatchBinaryOp { @@ -1144,7 +1144,7 @@ impl Tensor { /// Returns the data contained in a 2D tensor as a vector of vector of scalar values. pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> { - let (dim1, dim2) = self.shape().r2()?; + let (dim1, dim2) = self.dims2()?; let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; let mut rows = vec![]; @@ -1164,7 +1164,7 @@ impl Tensor { /// Returns the data contained in a 3D tensor. pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> { - let (dim1, dim2, dim3) = self.shape().r3()?; + let (dim1, dim2, dim3) = self.dims3()?; let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; let mut top_rows = vec![]; |