summaryrefslogtreecommitdiff
path: root/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensor.rs')
-rw-r--r--src/tensor.rs36
1 files changed, 19 insertions, 17 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index 02105573..e8e01d5c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -204,12 +204,13 @@ impl Tensor {
shape: self.shape().clone(),
});
}
+ let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+ let data = S::cpu_storage_as_slice(cpu_storage)?;
+ Ok::<_, Error>(data[0])
+ };
match &self.storage {
- Storage::Cpu(cpu_storage) => {
- let data = S::cpu_storage_as_slice(cpu_storage)?;
- Ok(data[0])
- }
- Storage::Cuda { .. } => todo!(),
+ Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
+ Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@@ -261,19 +262,20 @@ impl Tensor {
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
let (dim1, dim2) = self.shape().r2()?;
- match &self.storage {
- Storage::Cpu(cpu_storage) => {
- let data = S::cpu_storage_as_slice(cpu_storage)?;
- let mut rows = vec![];
- let mut src_index = self.strided_index();
- for _idx_row in 0..dim1 {
- let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
- rows.push(row)
- }
- assert!(src_index.next().is_none());
- Ok(rows)
+ let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
+ let data = S::cpu_storage_as_slice(cpu_storage)?;
+ let mut rows = vec![];
+ let mut src_index = self.strided_index();
+ for _idx_row in 0..dim1 {
+ let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+ rows.push(row)
}
- Storage::Cuda { .. } => todo!(),
+ assert!(src_index.next().is_none());
+ Ok(rows)
+ };
+ match &self.storage {
+ Storage::Cpu(storage) => from_cpu_storage(storage),
+ Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}