diff options
author | Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> | 2024-10-01 22:41:59 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-01 19:11:59 +0200 |
commit | a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc (patch) | |
tree | 79d10359ccc57c6ad31f05f4b0cd8a3513af04ef /candle-core/tests/tensor_tests.rs | |
parent | def4c6cdeef78e437846efcb46a23006f539dee4 (diff) | |
download | candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.tar.gz candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.tar.bz2 candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.zip |
Efficient implementation of `Tensor::ones()` for `metal` (#2512)
* WIP: hopefully better const impl
* with GPU
* More tests on
* Reverting primitive for
* Incorporating review changes - added check elem count check in kerner, using for call strategy
* rustfmt ran
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 4a76035c..e0cea15c 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> { Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?, [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ); + assert_eq!( + Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?, + [ + [ + half::f16::from_f32(1.0), + half::f16::from_f32(1.0), + half::f16::from_f32(1.0) + ], + [ + half::f16::from_f32(1.0), + half::f16::from_f32(1.0), + half::f16::from_f32(1.0) + ] + ], + ); + assert_eq!( + Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?, + [ + [ + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0) + ], + [ + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0) + ] + ], + ); Ok(()) } |