diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-01 10:22:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-01 11:22:46 +0200 |
commit | d4b6f6eef64d805e9fd678608378e1dfeb8278d2 (patch) | |
tree | f11f3dcd925ea6275f154302ef8e1cc8b8ead921 /candle-core/tests | |
parent | 957d604a7888bbf0243dbbca83a438db5132b48f (diff) | |
download | candle-d4b6f6eef64d805e9fd678608378e1dfeb8278d2.tar.gz candle-d4b6f6eef64d805e9fd678608378e1dfeb8278d2.tar.bz2 candle-d4b6f6eef64d805e9fd678608378e1dfeb8278d2.zip |
Add a minimal test for the metal bf16 matmul. (#2381)
Diffstat (limited to 'candle-core/tests')
-rw-r--r-- | candle-core/tests/matmul_tests.rs | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs index e3e18107..c1c16401 100644 --- a/candle-core/tests/matmul_tests.rs +++ b/candle-core/tests/matmul_tests.rs @@ -49,6 +49,20 @@ fn matmul(device: &Device) -> Result<()> { Ok(()) } +fn matmul_bf16(device: &Device) -> Result<()> { + if !device.supports_bf16() { + return Ok(()); + } + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?; + + let c = a.matmul(&b)?.to_dtype(DType::F32)?; + assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + Ok(()) +} + fn broadcast_matmul(device: &Device) -> Result<()> { let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?; let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?; @@ -97,6 +111,12 @@ fn mm_layout(device: &Device) -> Result<()> { test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); test_device!( + matmul_bf16, + matmul_bf16_cpu, + matmul_bf16_gpu, + matmul_bf16_metal +); +test_device!( broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu, |