summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r--	src/shape.rs	13
-rw-r--r--	src/tensor.rs	15
2 files changed, 19 insertions(+), 9 deletions(-)
diff --git a/src/shape.rs b/src/shape.rs
index ebc497cf..aa66e706 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -1,7 +1,7 @@
use crate::{Error, Result};
#[derive(Clone, PartialEq, Eq)]
-pub struct Shape(pub(crate) Vec<usize>);
+pub struct Shape(Vec<usize>);
impl std::fmt::Debug for Shape {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -63,6 +63,12 @@ impl From<(usize, usize, usize)> for Shape {
}
}
+impl From<Vec<usize>> for Shape {
+ fn from(dims: Vec<usize>) -> Self {
+ Self(dims)
+ }
+}
+
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
pub fn $fn_name(&self) -> Result<$out_type> {
@@ -142,6 +148,11 @@ impl Shape {
}
true
}
+
+ pub fn extend(mut self, additional_dims: &[usize]) -> Self {
+ self.0.extend(additional_dims);
+ self
+ }
}
#[cfg(test)]
diff --git a/src/tensor.rs b/src/tensor.rs
index 161a4787..66807594 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -283,9 +283,8 @@ impl Tensor {
});
}
- let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
- c_shape.extend(&[m, n]);
- let c_shape = Shape(c_shape);
+ let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
+ let c_stride = c_shape.stride_contiguous();
let batching: usize = a_dims[..dim - 2].iter().product();
let storage = self.storage.matmul_impl(
@@ -297,8 +296,8 @@ impl Tensor {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage,
- shape: c_shape.clone(),
- stride: c_shape.stride_contiguous(),
+ shape: c_shape,
+ stride: c_stride,
op: Some(Op::Matmul(self.clone(), rhs.clone())),
is_variable: false,
};
@@ -414,7 +413,6 @@ impl Tensor {
pub fn t(&self) -> Result<Tensor> {
let mut stride = self.stride().to_vec();
- let mut shape = self.shape().clone();
let n = stride.len();
if n < 2 {
return Err(Error::UnexpectedNumberOfDims {
@@ -423,12 +421,13 @@ impl Tensor {
shape: self.shape().clone(),
});
}
- (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
+ let mut dims = self.shape().dims().to_vec();
+ (dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
- shape,
+ shape: Shape::from(dims),
stride,
// TODO The op should have a backward
op: None,