| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-06-21 21:37:54 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-06-21 21:37:54 +0100 |
| commit | db35b310504ab97044b2c3826de72f9bccf86415 (patch) | |
| tree | 710596156a4c026d4dd2ba804fab79b6cdafae3b /src/tensor.rs | |
| parent | 7c317f9611c263f10d661b44151d3655a2fa3b90 (diff) | |
| parent | 7c46de9584fd4315b84d3bc4c28cf1b2bad7785d (diff) | |
Merge pull request #3 from LaurentMazare/cuda
Add Cuda support.
Diffstat (limited to 'src/tensor.rs')
-rw-r--r-- | src/tensor.rs | 38
1 file changed, 18 insertions, 20 deletions
```diff
diff --git a/src/tensor.rs b/src/tensor.rs
index 9ba412f9..02105573 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -84,7 +84,7 @@ impl Tensor {
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = shape.into();
@@ -101,22 +101,22 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, true)
     }
 
     pub fn ones_like(&self) -> Result<Self> {
-        Tensor::ones(self.shape(), self.dtype(), self.device())
+        Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = shape.into();
@@ -133,21 +133,21 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
+    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, true)
     }
 
     pub fn zeros_like(&self) -> Result<Self> {
-        Tensor::zeros(self.shape(), self.dtype(), self.device())
+        Tensor::zeros(self.shape(), self.dtype(), &self.device())
     }
 
     pub fn new_impl<A: crate::device::NdArray>(
         array: A,
-        device: Device,
+        device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = array.shape()?;
@@ -164,11 +164,11 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
         Self::new_impl(array, device, false)
     }
 
-    pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
+    pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
         Self::new_impl(array, device, true)
     }
 
@@ -250,7 +250,12 @@ impl Tensor {
                 let data = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(self.strided_index().map(|i| data[i]).collect())
             }
-            Storage::Cuda { .. } => todo!(),
+            Storage::Cuda(slice) => {
+                // TODO: Would it be possible to only fetch the necessary data?
+                let cpu_storage = slice.to_cpu_storage()?;
+                let data = S::cpu_storage_as_slice(&cpu_storage)?;
+                Ok(self.strided_index().map(|i| data[i]).collect())
+            }
         }
     }
 
@@ -305,14 +310,7 @@ impl Tensor {
     }
 
     pub fn is_contiguous(&self) -> bool {
-        let mut acc = 1;
-        for (&stride, &dim) in self.stride.iter().zip(self.shape.dims().iter()).rev() {
-            if stride != acc {
-                return false;
-            }
-            acc *= dim;
-        }
-        true
+        self.shape.is_contiguous(&self.stride)
     }
 
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
```
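The substance of the change is twofold: every constructor now borrows the device (`&Device`) instead of consuming it, so a single `Device` value can back any number of tensors, and the previously unimplemented `Storage::Cuda` read path now round-trips through `to_cpu_storage` (copying the whole device buffer to host memory; the in-diff TODO notes that fetching only the needed elements would be a possible optimization). A minimal sketch of the resulting call pattern, assuming the crate root re-exports `Tensor`, `DType`, `Device`, and `Result`, that `Device::Cpu` is a unit variant, and that tuples and fixed-size arrays satisfy the `Into<Shape>` and `NdArray` bounds at this revision (none of which is shown in this diff):

```rust
use candle::{DType, Device, Result, Tensor};

fn demo() -> Result<()> {
    // One device value can now back many tensors, since the
    // constructors only borrow it.
    let device = Device::Cpu;
    let ones = Tensor::ones((2, 3), DType::F32, &device)?;
    let zeros = Tensor::zeros((2, 3), DType::F32, &device)?;
    // `new` accepts anything implementing `crate::device::NdArray`;
    // a fixed-size float array is assumed to qualify here.
    let data = Tensor::new(&[1f32, 2.0, 3.0], &device)?;
    let _ = (ones, zeros, data);
    Ok(())
}
```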
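The final hunk folds the inline stride walk into `self.shape.is_contiguous(&self.stride)`. The helper's definition lives outside this diff (presumably in `src/shape.rs`); the sketch below is a hypothetical reconstruction that simply relocates the removed loop, assuming `Shape` wraps a `Vec<usize>` of dimension sizes:

```rust
// Hypothetical reconstruction of the helper the removed loop moved
// into; the real definition is not part of this diff.
pub struct Shape(Vec<usize>);

impl Shape {
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// A layout is C-contiguous when, walking dimensions from the
    /// innermost to the outermost, each stride equals the product of
    /// all inner dimension sizes.
    pub fn is_contiguous(&self, stride: &[usize]) -> bool {
        let mut acc = 1;
        for (&stride, &dim) in stride.iter().zip(self.dims().iter()).rev() {
            if stride != acc {
                return false;
            }
            acc *= dim;
        }
        true
    }
}
```

For example, a shape of `(2, 3)` with strides `[3, 1]` passes (the accumulator runs 1, then 3), while a transposed view with strides `[1, 3]` fails at the innermost dimension.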