diff options
Diffstat (limited to 'candle-nn/tests/ops.rs')
-rw-r--r-- | candle-nn/tests/ops.rs | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 20a66e75..24a49d06 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -140,7 +140,38 @@ fn rope(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn rope_thd(device: &Device) -> Result<()> {
+    use rand::{rngs::StdRng, Rng, SeedableRng};
+
+    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
+    let el_count = b_size * num_head * seq_len * head_dim;
+    let mut rng = StdRng::seed_from_u64(299792458);
+    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
+        .map(|_| rng.gen::<f32>())
+        .collect();
+    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
+        .map(|_| rng.gen::<f32>())
+        .collect();
+    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
+    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
+    let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
+    let rope1 = {
+        let src = src.transpose(1, 2)?.contiguous()?;
+        candle_nn::rotary_emb::rope_thd(&src, &cos, &sin)?.transpose(1, 2)?
+    };
+    let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
+    let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if device.is_cpu() {
+        assert_eq!(sum_diff, 0.);
+    } else {
+        assert!(sum_diff < 1e-4);
+    }
+    Ok(())
+}
+
 test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
 test_device!(rope, rope_cpu, rope_gpu, rope_metal);
+test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
 test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
 test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);