summaryrefslogtreecommitdiff
path: root/candle-core/src/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r--candle-core/src/device.rs20
1 files changed, 20 insertions, 0 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 65232839..84716249 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -81,6 +81,26 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
}
+impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
+ for &[[[[S; N4]; N3]; N2]; N1]
+{
+ fn shape(&self) -> Result<Shape> {
+ Ok(Shape::from((N1, N2, N3, N4)))
+ }
+
+ fn to_cpu_storage(&self) -> CpuStorage {
+ let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
+ for i1 in 0..N1 {
+ for i2 in 0..N2 {
+ for i3 in 0..N3 {
+ vec.extend(self[i1][i2][i3])
+ }
+ }
+ }
+ S::to_cpu_storage_owned(vec)
+ }
+}
+
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))