summaryrefslogtreecommitdiff
path: root/candle-nn/tests/ops.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/tests/ops.rs')
-rw-r--r--candle-nn/tests/ops.rs32
1 files changed, 30 insertions, 2 deletions
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index af883b85..20a66e75 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -86,7 +86,7 @@ fn softmax_numerical_stability() -> Result<()> {
Ok(())
}
-fn rope(device: &Device) -> Result<()> {
+fn ropei(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
@@ -107,12 +107,40 @@ fn rope(device: &Device) -> Result<()> {
let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cpu() {
assert_eq!(sum_diff, 0.);
- } else if device.is_cuda() {
+ } else {
+ assert!(sum_diff < 1e-4);
+ }
+ Ok(())
+}
+
+fn rope(device: &Device) -> Result<()> {
+ use rand::{rngs::StdRng, Rng, SeedableRng};
+
+ let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
+ let el_count = b_size * num_head * seq_len * head_dim;
+ let mut rng = StdRng::seed_from_u64(299792458);
+ let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+ let cos: Vec<f32> = (0..seq_len * head_dim / 2)
+ .map(|_| rng.gen::<f32>())
+ .collect();
+ let sin: Vec<f32> = (0..seq_len * head_dim / 2)
+ .map(|_| rng.gen::<f32>())
+ .collect();
+ let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
+ let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
+ let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
+ let rope1 = candle_nn::rotary_emb::rope(&src, &cos, &sin)?;
+ let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
+ let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+ if device.is_cpu() {
+ assert_eq!(sum_diff, 0.);
+ } else {
assert!(sum_diff < 1e-4);
}
Ok(())
}
+test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);