summaryrefslogtreecommitdiff
path: root/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensor.rs')
-rw-r--r--src/tensor.rs31
1 files changed, 16 insertions, 15 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index 09e5d66c..161a4787 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -359,23 +359,24 @@ impl Tensor {
pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
let (dim1, dim2, dim3) = self.shape().r3()?;
- match &self.storage {
- Storage::Cpu(cpu_storage) => {
- let data = S::cpu_storage_as_slice(cpu_storage)?;
- let mut top_rows = vec![];
- let mut src_index = self.strided_index();
- for _idx in 0..dim1 {
- let mut rows = vec![];
- for _jdx in 0..dim2 {
- let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
- rows.push(row)
- }
- top_rows.push(rows);
+ let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+ let data = S::cpu_storage_as_slice(cpu_storage)?;
+ let mut top_rows = vec![];
+ let mut src_index = self.strided_index();
+ for _idx in 0..dim1 {
+ let mut rows = vec![];
+ for _jdx in 0..dim2 {
+ let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+ rows.push(row)
}
- assert!(src_index.next().is_none());
- Ok(top_rows)
+ top_rows.push(rows);
}
- Storage::Cuda { .. } => todo!(),
+ assert!(src_index.next().is_none());
+ Ok(top_rows)
+ };
+ match &self.storage {
+ Storage::Cpu(storage) => from_cpu_storage(storage),
+ Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}