diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-22 11:01:49 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-22 11:01:49 +0100 |
commit | 87a37b3bf3b6fd5034269c10c21c8f91e0223eb0 (patch) | |
tree | 7a25c28df9bb0eda94a89a61da95c8cdf0f55c06 /src | |
parent | 083ced4428819f123ce8549ead9163055ac1ac64 (diff) | |
download | candle-87a37b3bf3b6fd5034269c10c21c8f91e0223eb0.tar.gz candle-87a37b3bf3b6fd5034269c10c21c8f91e0223eb0.tar.bz2 candle-87a37b3bf3b6fd5034269c10c21c8f91e0223eb0.zip |
Retrieve data from the gpu.
Diffstat (limited to 'src')
-rw-r--r-- | src/tensor.rs | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/src/tensor.rs b/src/tensor.rs index 02105573..e8e01d5c 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -204,12 +204,13 @@ impl Tensor { shape: self.shape().clone(), }); } + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + Ok::<_, Error>(data[0]) + }; match &self.storage { - Storage::Cpu(cpu_storage) => { - let data = S::cpu_storage_as_slice(cpu_storage)?; - Ok(data[0]) - } - Storage::Cuda { .. } => todo!(), + Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } @@ -261,19 +262,20 @@ impl Tensor { pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> { let (dim1, dim2) = self.shape().r2()?; - match &self.storage { - Storage::Cpu(cpu_storage) => { - let data = S::cpu_storage_as_slice(cpu_storage)?; - let mut rows = vec![]; - let mut src_index = self.strided_index(); - for _idx_row in 0..dim1 { - let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect(); - rows.push(row) - } - assert!(src_index.next().is_none()); - Ok(rows) + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + let mut rows = vec![]; + let mut src_index = self.strided_index(); + for _idx_row in 0..dim1 { + let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect(); + rows.push(row) } - Storage::Cuda { .. } => todo!(), + assert!(src_index.next().is_none()); + Ok(rows) + }; + match &self.storage { + Storage::Cpu(storage) => from_cpu_storage(storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } } |