diff options
author | nicolas <nicolas@nicolass-MacBook-Pro.local> | 2023-12-12 17:41:56 +0100 |
---|---|---|
committer | nicolas <nicolas@nicolass-MacBook-Pro.local> | 2023-12-12 17:41:56 +0100 |
commit | 87dc559817db11f8d8c409cda959528e57e1db31 (patch) | |
tree | 3f7ec04a0facab3378158ae3ba84416d56fd37a7 /candle-transformers | |
parent | da0af3cb3e58d38476a20f4465744093a3b75dd4 (diff) | |
download | candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.gz candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.bz2 candle-87dc559817db11f8d8c409cda959528e57e1db31.zip |
Lots of updates including some stack of command buffers.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/mixformer.rs | 46 |
2 files changed, 43 insertions, 4 deletions
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index af4e04b7..e72cab69 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda"] flash-attn = ["cuda", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"] +metal = ["candle/metal", "candle-nn/metal"] diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e822ca14..c8dae511 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -142,10 +142,9 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - Ok(Self { - sin: freqs.sin()?, - cos: freqs.cos()?, - }) + let sin = freqs.sin()?; + let cos = freqs.cos()?; + Ok(Self { sin, cos }) } fn apply_rotary_emb_qkv( @@ -273,6 +272,10 @@ impl MHA { } fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> { + let view = xs.to_string(); + if view.contains("NaN") { + panic!("NaN"); + } let _enter = self.span.enter(); let (b_size, seq_len, _n_embd) = xs.dims3()?; let qkv = self @@ -408,3 +411,38 @@ impl MixFormerSequentialForCausalLM { self.blocks.iter_mut().for_each(|b| b.clear_kv_cache()) } } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_rotary() { + let dev = Device::new_metal(0).unwrap(); + for i in 0..10000 { + let dim = 8; + let max_seq_len = 12; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap(); + let t = Tensor::arange(0u32, max_seq_len as u32, &dev) + .unwrap() + .to_dtype(DType::F32) + .unwrap() + .reshape((max_seq_len, 1)) + .unwrap(); + let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 1.0); + let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.1); + let freqs = t.matmul(&inv_freq).unwrap(); + let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.1); + let sin = freqs.sin().unwrap().contiguous().unwrap(); + let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap(); + assert_eq!(x, 0.099833414); + } + } +} |