summary | refs | log | tree | commit | diff
path: root/candle-transformers
diff options
context:
space:
mode:
author: nicolas <nicolas@nicolass-MacBook-Pro.local> 2023-12-12 17:41:56 +0100
committer: nicolas <nicolas@nicolass-MacBook-Pro.local> 2023-12-12 17:41:56 +0100
commit: 87dc559817db11f8d8c409cda959528e57e1db31 (patch)
tree: 3f7ec04a0facab3378158ae3ba84416d56fd37a7 /candle-transformers
parent: da0af3cb3e58d38476a20f4465744093a3b75dd4 (diff)
download: candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.gz
candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.bz2
candle-87dc559817db11f8d8c409cda959528e57e1db31.zip
Lots of updates including some stack of command buffers.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- candle-transformers/Cargo.toml | 1
-rw-r--r-- candle-transformers/src/models/mixformer.rs | 46
2 files changed, 43 insertions, 4 deletions
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index af4e04b7..e72cab69 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
+metal = ["candle/metal", "candle-nn/metal"]
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index e822ca14..c8dae511 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -142,10 +142,9 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
- Ok(Self {
- sin: freqs.sin()?,
- cos: freqs.cos()?,
- })
+ let sin = freqs.sin()?;
+ let cos = freqs.cos()?;
+ Ok(Self { sin, cos })
}
fn apply_rotary_emb_qkv(
@@ -273,6 +272,10 @@ impl MHA {
}
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+ let view = xs.to_string();
+ if view.contains("NaN") {
+ panic!("NaN");
+ }
let _enter = self.span.enter();
let (b_size, seq_len, _n_embd) = xs.dims3()?;
let qkv = self
@@ -408,3 +411,38 @@ impl MixFormerSequentialForCausalLM {
self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ #[test]
+ fn test_rotary() {
+ let dev = Device::new_metal(0).unwrap();
+ for i in 0..10000 {
+ let dim = 8;
+ let max_seq_len = 12;
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev).unwrap();
+ let t = Tensor::arange(0u32, max_seq_len as u32, &dev)
+ .unwrap()
+ .to_dtype(DType::F32)
+ .unwrap()
+ .reshape((max_seq_len, 1))
+ .unwrap();
+ let x: f32 = t.i((1, 0)).unwrap().to_scalar().unwrap();
+ assert_eq!(x, 1.0);
+ let x: f32 = inv_freq.i((0, 1)).unwrap().to_scalar().unwrap();
+ assert_eq!(x, 0.1);
+ let freqs = t.matmul(&inv_freq).unwrap();
+ let x: f32 = freqs.i((1, 1)).unwrap().to_scalar().unwrap();
+ assert_eq!(x, 0.1);
+ let sin = freqs.sin().unwrap().contiguous().unwrap();
+ let x: f32 = sin.i((1, 1)).unwrap().to_scalar().unwrap();
+ assert_eq!(x, 0.099833414);
+ }
+ }
+}