summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-06-22 11:01:49 +0100
committerlaurent <laurent.mazare@gmail.com>2023-06-22 11:01:49 +0100
commit87a37b3bf3b6fd5034269c10c21c8f91e0223eb0 (patch)
tree7a25c28df9bb0eda94a89a61da95c8cdf0f55c06 /src
parent083ced4428819f123ce8549ead9163055ac1ac64 (diff)
downloadcandle-87a37b3bf3b6fd5034269c10c21c8f91e0223eb0.tar.gz
candle-87a37b3bf3b6fd5034269c10c21c8f91e0223eb0.tar.bz2
candle-87a37b3bf3b6fd5034269c10c21c8f91e0223eb0.zip
Retrieve data from the gpu.
Diffstat (limited to 'src')
-rw-r--r--src/tensor.rs36
1 files changed, 19 insertions, 17 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index 02105573..e8e01d5c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -204,12 +204,13 @@ impl Tensor {
shape: self.shape().clone(),
});
}
+ let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+ let data = S::cpu_storage_as_slice(cpu_storage)?;
+ Ok::<_, Error>(data[0])
+ };
match &self.storage {
- Storage::Cpu(cpu_storage) => {
- let data = S::cpu_storage_as_slice(cpu_storage)?;
- Ok(data[0])
- }
- Storage::Cuda { .. } => todo!(),
+ Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
+ Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@@ -261,19 +262,20 @@ impl Tensor {
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
let (dim1, dim2) = self.shape().r2()?;
- match &self.storage {
- Storage::Cpu(cpu_storage) => {
- let data = S::cpu_storage_as_slice(cpu_storage)?;
- let mut rows = vec![];
- let mut src_index = self.strided_index();
- for _idx_row in 0..dim1 {
- let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
- rows.push(row)
- }
- assert!(src_index.next().is_none());
- Ok(rows)
+ let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+ let data = S::cpu_storage_as_slice(cpu_storage)?;
+ let mut rows = vec![];
+ let mut src_index = self.strided_index();
+ for _idx_row in 0..dim1 {
+ let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+ rows.push(row)
}
- Storage::Cuda { .. } => todo!(),
+ assert!(src_index.next().is_none());
+ Ok(rows)
+ };
+ match &self.storage {
+ Storage::Cpu(storage) => from_cpu_storage(storage),
+ Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}