diff options
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r-- | candle-core/src/device.rs | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 65232839..84716249 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -81,6 +81,26 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray } } +impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray + for &[[[[S; N4]; N3]; N2]; N1] +{ + fn shape(&self) -> Result<Shape> { + Ok(Shape::from((N1, N2, N3, N4))) + } + + fn to_cpu_storage(&self) -> CpuStorage { + let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4); + for i1 in 0..N1 { + for i2 in 0..N2 { + for i3 in 0..N3 { + vec.extend(self[i1][i2][i3]) + } + } + } + S::to_cpu_storage_owned(vec) + } +} + impl Device { pub fn new_cuda(ordinal: usize) -> Result<Self> { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) |