summaryrefslogtreecommitdiff
path: root/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensor.rs')
-rw-r--r--  src/tensor.rs  38
1 file changed, 18 insertions, 20 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index 9ba412f9..02105573 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -84,7 +84,7 @@ impl Tensor {
fn ones_impl<S: Into<Shape>>(
shape: S,
dtype: DType,
- device: Device,
+ device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
@@ -101,22 +101,22 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
- pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+ pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::ones_impl(shape, dtype, device, false)
}
- pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+ pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::ones_impl(shape, dtype, device, true)
}
pub fn ones_like(&self) -> Result<Self> {
- Tensor::ones(self.shape(), self.dtype(), self.device())
+ Tensor::ones(self.shape(), self.dtype(), &self.device())
}
fn zeros_impl<S: Into<Shape>>(
shape: S,
dtype: DType,
- device: Device,
+ device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
@@ -133,21 +133,21 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
- pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+ pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::zeros_impl(shape, dtype, device, false)
}
- pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+ pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::zeros_impl(shape, dtype, device, true)
}
pub fn zeros_like(&self) -> Result<Self> {
- Tensor::zeros(self.shape(), self.dtype(), self.device())
+ Tensor::zeros(self.shape(), self.dtype(), &self.device())
}
pub fn new_impl<A: crate::device::NdArray>(
array: A,
- device: Device,
+ device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = array.shape()?;
@@ -164,11 +164,11 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
- pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+ pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
Self::new_impl(array, device, false)
}
- pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+ pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
Self::new_impl(array, device, true)
}
@@ -250,7 +250,12 @@ impl Tensor {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(self.strided_index().map(|i| data[i]).collect())
}
- Storage::Cuda { .. } => todo!(),
+ Storage::Cuda(slice) => {
+ // TODO: Would it be possible to only fetch the necessary data?
+ let cpu_storage = slice.to_cpu_storage()?;
+ let data = S::cpu_storage_as_slice(&cpu_storage)?;
+ Ok(self.strided_index().map(|i| data[i]).collect())
+ }
}
}
@@ -305,14 +310,7 @@ impl Tensor {
}
pub fn is_contiguous(&self) -> bool {
- let mut acc = 1;
- for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() {
- if stride != acc {
- return false;
- }
- acc *= dim;
- }
- true
+ self.shape.is_contiguous(&self.stride)
}
/// Return all the nodes that lead to this value in a topologically sorted vec, the first