summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-22 11:39:27 +0200
committerGitHub <noreply@github.com>2023-07-22 10:39:27 +0100
commit43c72232927ca80c850a73ce977c2063d5a2dcf5 (patch)
treec93c07984e06b1925313f4f641a8b1a3956fc0ed /candle-core/src/tensor.rs
parent52c5d8c087f6a2ee91b807467860eb3e96bb6267 (diff)
downloadcandle-43c72232927ca80c850a73ce977c2063d5a2dcf5.tar.gz
candle-43c72232927ca80c850a73ce977c2063d5a2dcf5.tar.bz2
candle-43c72232927ca80c850a73ce977c2063d5a2dcf5.zip
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs12
1 files changed, 6 insertions, 6 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 8ba0ba43..561f1863 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -772,7 +772,7 @@ impl Tensor {
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
- let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
+ let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = match *self.dims() {
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
[c_in, l_in] => (None, c_in, l_in),
@@ -931,8 +931,8 @@ impl Tensor {
.bt())?
}
let ids_shape = ids.shape();
- let seq_len = ids_shape.r1()?;
- let (_, hidden_size) = rhs.shape().r2()?;
+ let seq_len = ids_shape.dims1()?;
+ let (_, hidden_size) = rhs.dims2()?;
let storage = ids
.storage()
.embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
@@ -1013,7 +1013,7 @@ impl Tensor {
// The number of element in indexes must match the dimension on which the add is
// performed on the source tensor (and the index values from `indexes` are taken from
// the target tensor self)
- mismatch || source_dims[dim] != indexes.shape().r1()?
+ mismatch || source_dims[dim] != indexes.dims1()?
};
if mismatch {
Err(Error::ShapeMismatchBinaryOp {
@@ -1144,7 +1144,7 @@ impl Tensor {
/// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
- let (dim1, dim2) = self.shape().r2()?;
+ let (dim1, dim2) = self.dims2()?;
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
let data = S::cpu_storage_as_slice(cpu_storage)?;
let mut rows = vec![];
@@ -1164,7 +1164,7 @@ impl Tensor {
/// Returns the data contained in a 3D tensor.
pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
- let (dim1, dim2, dim3) = self.shape().r3()?;
+ let (dim1, dim2, dim3) = self.dims3()?;
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
let data = S::cpu_storage_as_slice(cpu_storage)?;
let mut top_rows = vec![];