summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs27
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c8b255dd..f565972a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1169,3 +1169,30 @@ fn tril_triu_eye() -> Result<()> {
);
Ok(())
}
+
+#[test]
+fn cumsum() -> Result<()> {
+ let t = &[3f32, 1., 4., 1., 5.];
+ let t = Tensor::new(t, &Device::Cpu)?;
+ assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, [3., 4., 8., 9., 14.]);
+ let t = t.unsqueeze(1)?;
+ assert_eq!(
+ t.cumsum(0)?.to_vec2::<f32>()?,
+ [[3.0], [4.0], [8.0], [9.0], [14.0]]
+ );
+ assert_eq!(
+ t.cumsum(1)?.to_vec2::<f32>()?,
+ [[3.0], [1.0], [4.0], [1.0], [5.0]]
+ );
+ let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+ let t = Tensor::new(t, &Device::Cpu)?;
+ assert_eq!(
+ t.cumsum(1)?.to_vec2::<f32>()?,
+ [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]],
+ );
+ assert_eq!(
+ t.cumsum(0)?.to_vec2::<f32>()?,
+ [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]]
+ );
+ Ok(())
+}