summary | refs | log | tree | commit | diff
path: root/candle-core/tests
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-08-01 10:22:46 +0100
committer GitHub <noreply@github.com> 2024-08-01 11:22:46 +0200
commit d4b6f6eef64d805e9fd678608378e1dfeb8278d2 (patch)
tree f11f3dcd925ea6275f154302ef8e1cc8b8ead921 /candle-core/tests
parent 957d604a7888bbf0243dbbca83a438db5132b48f (diff)
download candle-d4b6f6eef64d805e9fd678608378e1dfeb8278d2.tar.gz
candle-d4b6f6eef64d805e9fd678608378e1dfeb8278d2.tar.bz2
candle-d4b6f6eef64d805e9fd678608378e1dfeb8278d2.zip
Add a minimal test for the metal bf16 matmul. (#2381)
Diffstat (limited to 'candle-core/tests')
-rw-r--r-- candle-core/tests/matmul_tests.rs 20
1 file changed, 20 insertions, 0 deletions
diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs
index e3e18107..c1c16401 100644
--- a/candle-core/tests/matmul_tests.rs
+++ b/candle-core/tests/matmul_tests.rs
@@ -49,6 +49,20 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}
+fn matmul_bf16(device: &Device) -> Result<()> { // Minimal bf16 matmul check: multiplies the 2x2 matrix [[1, 2], [3, 4]] by itself on `device`.
+ if !device.supports_bf16() { // Devices that report no bf16 support are skipped; the test then passes trivially.
+ return Ok(());
+ }
+ let data = vec![1.0f32, 2.0, 3.0, 4.0];
+ let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?; // Built as an f32 (2, 2) tensor, then converted to bf16.
+ let data = vec![1.0f32, 2.0, 3.0, 4.0]; // Same values as above, so `b` holds the same matrix as `a`.
+ let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?;
+
+ let c = a.matmul(&b)?.to_dtype(DType::F32)?; // matmul runs on the bf16 tensors; the product is converted back to f32 for comparison.
+ assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]); // Exact equality: the expected entries are small integers, which bf16 represents exactly.
+ Ok(())
+}
+
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
@@ -97,6 +111,12 @@ fn mm_layout(device: &Device) -> Result<()> {
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
+ matmul_bf16,
+ matmul_bf16_cpu,
+ matmul_bf16_gpu,
+ matmul_bf16_metal
+);
+test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,