summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
authorAnubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>2024-10-01 22:41:59 +0530
committerGitHub <noreply@github.com>2024-10-01 19:11:59 +0200
commita2bcc227df64b22cfbc54b5f96c995bf3a38c7bc (patch)
tree79d10359ccc57c6ad31f05f4b0cd8a3513af04ef /candle-core/tests/tensor_tests.rs
parentdef4c6cdeef78e437846efcb46a23006f539dee4 (diff)
downloadcandle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.tar.gz
candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.tar.bz2
candle-a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc.zip
Efficient implementation of `Tensor::ones()` for `metal` (#2512)
* WIP: hopefully better const impl * with GPU * More tests on * Reverting primitive for * Incorporating review changes - added check elem count check in kerner, using for call strategy * rustfmt ran
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs30
1 files changed, 30 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 4a76035c..e0cea15c 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
+ assert_eq!(
+ Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
+ [
+ [
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0)
+ ],
+ [
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0),
+ half::f16::from_f32(1.0)
+ ]
+ ],
+ );
+ assert_eq!(
+ Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
+ [
+ [
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0)
+ ],
+ [
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0),
+ half::bf16::from_f32(1.0)
+ ]
+ ],
+ );
Ok(())
}