diff options
Diffstat (limited to 'candle-nn/tests/ops.rs')
-rw-r--r-- | candle-nn/tests/ops.rs | 34 |
1 files changed, 30 insertions, 4 deletions
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs index 5ca01b37..c1e3031f 100644 --- a/candle-nn/tests/ops.rs +++ b/candle-nn/tests/ops.rs @@ -4,11 +4,9 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle::{test_utils::to_vec3_round, Device, Result, Tensor}; +use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor}; -#[test] -fn softmax() -> Result<()> { - let device = &Device::Cpu; +fn softmax(device: &Device) -> Result<()> { let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; let tensor = Tensor::new(data, device)?; let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?; @@ -54,6 +52,31 @@ fn softmax() -> Result<()> { Ok(()) } +fn rms_norm(device: &Device) -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?; + let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?; + assert_eq!( + to_vec3_round(&t, 4)?, + &[ + [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]], + [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]] + ] + ); + let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?; + assert_eq!( + to_vec3_round(&t2, 4)?, + &[ + [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]], + [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]] + ] + ); + let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?; + assert!(diff < 1e-5); + Ok(()) +} + #[test] fn softmax_numerical_stability() -> Result<()> { let dev = &Device::Cpu; @@ -62,3 +85,6 @@ fn softmax_numerical_stability() -> Result<()> { assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]); Ok(()) } + +test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal); +test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal); |