summary | refs | log | tree | commit | diff
path: root/candle-nn/tests/ops.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/tests/ops.rs')
-rw-r--r-- candle-nn/tests/ops.rs 34
1 files changed, 30 insertions, 4 deletions
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 5ca01b37..c1e3031f 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -4,11 +4,9 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle::{test_utils::to_vec3_round, Device, Result, Tensor};
+use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
-#[test]
-fn softmax() -> Result<()> {
- let device = &Device::Cpu;
+fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
@@ -54,6 +52,31 @@ fn softmax() -> Result<()> {
Ok(())
}
+fn rms_norm(device: &Device) -> Result<()> { // Checks rms_norm against rms_norm_slow on a fixed 2x2x3 input created on `device`.
+ let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+ let tensor = Tensor::new(data, device)?;
+ let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?; // three weights, one per element of each innermost row of `data`
+ let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?; // NOTE(review): 1e-5 is presumably the epsilon term - confirm against candle_nn::ops::rms_norm
+ assert_eq!(
+ to_vec3_round(&t, 4)?, // NOTE(review): 4 is presumably the number of decimals kept - confirm in candle::test_utils
+ &[
+ [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]], // first row [3, 1, 4]: 3 / sqrt((9 + 1 + 16) / 3) * alpha[0] ~= 1.019
+ [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
+ ]
+ );
+ let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?; // second implementation, called with the same inputs
+ assert_eq!(
+ to_vec3_round(&t2, 4)?,
+ &[
+ [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]], // same expected values as the first assertion
+ [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]]
+ ]
+ );
+ let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?; // sum of absolute element-wise differences, read back as one f32
+ assert!(diff < 1e-5); // unrounded outputs of both implementations must agree within a total of 1e-5
+ Ok(())
+}
+
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
@@ -62,3 +85,6 @@ fn softmax_numerical_stability() -> Result<()> {
assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
Ok(())
}
+
+test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal); // NOTE(review): macro imported from candle; presumably emits the per-device #[test] entry points named here for `softmax` - confirm in its definition
+test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal); // same macro applied to the `rms_norm` test body