diff options
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 4a76035c..e0cea15c 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> { Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?, [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ); + assert_eq!( + Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?, + [ + [ + half::f16::from_f32(1.0), + half::f16::from_f32(1.0), + half::f16::from_f32(1.0) + ], + [ + half::f16::from_f32(1.0), + half::f16::from_f32(1.0), + half::f16::from_f32(1.0) + ] + ], + ); + assert_eq!( + Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?, + [ + [ + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0) + ], + [ + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0) + ] + ], + ); Ok(()) } |