diff options
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index db0fe98a..578e8ac9 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -425,6 +425,17 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) { } } +impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4]) + } +} + extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); |