Diffstat (limited to 'candle-examples/examples/llama_multiprocess/model.rs')
-rw-r--r-- | candle-examples/examples/llama_multiprocess/model.rs | 90
1 file changed, 28 insertions, 62 deletions
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 230a2f1e..573eae11 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -3,7 +3,6 @@ use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shap
 use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
-use std::collections::HashMap;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
@@ -137,17 +136,14 @@ impl Config {
 
 #[derive(Clone)]
 pub struct Cache {
-    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
-    pub use_kv_cache: bool,
     #[allow(clippy::type_complexity)]
     kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
     cos: Tensor,
     sin: Tensor,
-    device: Device,
 }
 
 impl Cache {
-    pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
+    pub fn new(config: &Config, device: &Device) -> Result<Self> {
         // precompute freqs_cis
         let n_elem = config.n_embd / config.n_head;
         let theta: Vec<_> = (0..n_elem)
@@ -162,31 +158,14 @@ impl Cache {
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
         let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
-        let cos = idx_theta.cos()?;
-        let sin = idx_theta.sin()?;
+        let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
+        let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
         Ok(Self {
-            masks: Arc::new(Mutex::new(HashMap::new())),
-            use_kv_cache,
             kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
-            device: device.clone(),
             cos,
             sin,
         })
     }
-
-    fn mask(&self, t: usize) -> Result<Tensor> {
-        let mut masks = self.masks.lock().unwrap();
-        if let Some(mask) = masks.get(&t) {
-            Ok(mask.clone())
-        } else {
-            let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
-                .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
-            masks.insert(t, mask.clone());
-            Ok(mask)
-        }
-    }
 }
 
 fn silu(xs: &Tensor) -> Result<Tensor> {
@@ -260,7 +239,6 @@ impl CausalSelfAttention {
     }
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
-        let x_dtype = x.dtype();
         let (b_sz, seq_len, _) = x.shape().dims3()?;
 
         let qkv = self.qkv_proj.forward(x)?;
@@ -282,51 +260,46 @@ impl CausalSelfAttention {
         let q = q
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
         let k = k
             .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
         let mut v = v
             .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
 
-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
-                let k_seq_len = k.dims()[1];
-                if k_seq_len > MAX_SEQ_LEN {
-                    k = k
-                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
-                        .contiguous()?
-                }
-                let v_seq_len = v.dims()[1];
-                if v_seq_len > 2 * MAX_SEQ_LEN {
-                    v = v
-                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
-                        .contiguous()?
-                }
+        let mut cache = self.cache.kvs.lock().unwrap();
+        if let Some((cache_k, cache_v)) = &cache[block_idx] {
+            k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+            v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
+            let k_seq_len = k.dims()[1];
+            if k_seq_len > MAX_SEQ_LEN {
+                k = k
+                    .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                    .contiguous()?
+            }
+            let v_seq_len = v.dims()[1];
+            if v_seq_len > 2 * MAX_SEQ_LEN {
+                v = v
+                    .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                    .contiguous()?
+            }
             }
-            cache[block_idx] = Some((k.clone(), v.clone()))
         }
+        cache[block_idx] = Some((k.clone(), v.clone()));
 
         let k = self.repeat_kv(k)?;
         let v = self.repeat_kv(v)?;
 
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+        let q = q.transpose(1, 2)?;
+        let k = k.transpose(1, 2)?;
+        let v = v.transpose(1, 2)?;
+        let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+        let y =
+            candle_flash_attn::flash_attn(q, k, v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = y.to_dtype(x_dtype)?;
         let y = self.o_proj.forward(&y)?;
         Ok(y)
     }
@@ -363,13 +336,6 @@ impl CausalSelfAttention {
     }
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 struct Mlp {
     c_fc1: TensorParallelColumnLinear,
     c_fc2: TensorParallelColumnLinear,
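Note on the Cache API change above: with the mask cache and the use_kv_cache flag removed, the KV cache is always enabled and Cache::new takes only the config and the device. A minimal call-site sketch (the config and device bindings are assumed to exist in the caller; they are not part of this diff):

    // KV caching is now unconditional; the old boolean flag is simply dropped.
    let cache = Cache::new(&config, &device)?;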
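The masked matmul/softmax/matmul attention in forward() is replaced by a single fused call into candle_flash_attn. Below is a minimal sketch of the new path, assuming the flash_attn(q, k, v, softmax_scale, causal) call shape used in this diff; the attend helper is hypothetical and only illustrates the tensor layout and the causal flag, while the real code inlines these lines in CausalSelfAttention::forward and the exact flash_attn signature may differ in other candle versions.

    use candle::{Result, Tensor};

    // Sketch only: q/k/v arrive as (b_sz, n_heads, seq_len, head_dim) in f16 on a CUDA device.
    fn attend(q: Tensor, k: Tensor, v: Tensor, head_dim: usize, seq_len: usize) -> Result<Tensor> {
        // The flash-attn kernel expects (b_sz, seq_len, n_heads, head_dim), hence the transposes.
        let q = q.transpose(1, 2)?;
        let k = k.transpose(1, 2)?;
        let v = v.transpose(1, 2)?;
        let softmax_scale = 1f32 / (head_dim as f32).sqrt();
        // Causal masking happens inside the kernel and is only needed for the prompt pass
        // (seq_len > 1); single-token decoding has nothing to mask, which is why Cache::mask()
        // and masked_fill() could be deleted.
        let y = candle_flash_attn::flash_attn(q, k, v, softmax_scale, seq_len > 1)?;
        // Back to (b_sz, n_heads, seq_len, head_dim) so the later reshape to (b_sz, seq_len, n_embd) still works.
        y.transpose(1, 2)
    }

Because the fused kernel operates on half-precision inputs, the per-call upcasts of q/k/v to f32 (and the matching downcast via x_dtype) are no longer needed, and the rotary cos/sin tables are now precomputed once in f16 in Cache::new.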