diff options
Diffstat (limited to 'candle-nn/tests/ops.rs')
-rw-r--r-- | candle-nn/tests/ops.rs | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs index af883b85..20a66e75 100644 --- a/candle-nn/tests/ops.rs +++ b/candle-nn/tests/ops.rs @@ -86,7 +86,7 @@ fn softmax_numerical_stability() -> Result<()> { Ok(()) } -fn rope(device: &Device) -> Result<()> { +fn ropei(device: &Device) -> Result<()> { use rand::{rngs::StdRng, Rng, SeedableRng}; let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); @@ -107,12 +107,40 @@ fn rope(device: &Device) -> Result<()> { let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?; if device.is_cpu() { assert_eq!(sum_diff, 0.); - } else if device.is_cuda() { + } else { + assert!(sum_diff < 1e-4); + } + Ok(()) +} + +fn rope(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); + let el_count = b_size * num_head * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect(); + let cos: Vec<f32> = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::<f32>()) + .collect(); + let sin: Vec<f32> = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::<f32>()) + .collect(); + let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?; + let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?; + let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?; + let rope1 = candle_nn::rotary_emb::rope(&src, &cos, &sin)?; + let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?; + let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?; + if device.is_cpu() { + assert_eq!(sum_diff, 0.); + } else { assert!(sum_diff < 1e-4); } Ok(()) } +test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal); test_device!(rope, rope_cpu, rope_gpu, rope_metal); test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal); test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal); |