summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs30
1 files changed, 30 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 4a76035c..e0cea15c 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
+ assert_eq!(
+ Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
+ [
+ [
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0)
+ ],
+ [
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0)
+ ]
+ ],
+ );
+ assert_eq!(
+ Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
+ [
+ [
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0)
+ ],
+ [
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0)
+ ]
+ ],
+ );
Ok(())
}