-rw-r--r--  candle-examples/Cargo.toml                            |  2
-rw-r--r--  candle-examples/examples/llama_multiprocess/main.rs   | 39
-rw-r--r--  candle-examples/examples/llama_multiprocess/model.rs  | 90
3 files changed, 39 insertions, 92 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index a595e74d..5a04130f 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -46,4 +46,4 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
 
 [[example]]
 name = "llama_multiprocess"
-required-features = ["cuda", "nccl"]
+required-features = ["cuda", "nccl", "flash-attn"]
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index f9e87432..679e5faa 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -20,7 +20,7 @@ use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
 use std::io::Write;
 use std::rc::Rc;
@@ -83,10 +83,6 @@ Upon my target three fair-shining suns.
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
     #[arg(long)]
     num_shards: usize,
 
@@ -113,15 +109,8 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,
 
-    /// Use f32 computations rather than f16.
-    #[arg(long)]
-    use_f32: bool,
-
     #[arg(long)]
     model_id: Option<String>,
-
-    #[arg(long)]
-    v2: bool,
 }
 
 fn main() -> Result<()> {
@@ -130,26 +119,22 @@ fn main() -> Result<()> {
     let args = Args::parse();
 
     let config = Config::config_7b();
-    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+    let dtype = DType::F16;
 
     let api = Api::new()?;
 
-    let model_id = args.model_id.unwrap_or_else(|| {
-        if args.v2 {
-            "meta-llama/Llama-2-7b-hf".to_string()
-        } else {
-            "Narsil/amall-7b".to_string()
-        }
-    });
+    let model_id = args
+        .model_id
+        .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
     println!("loading the model weights from {model_id}");
-    let repo = Repo::new(model_id, RepoType::Model);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let api = api.model(model_id);
+    let tokenizer_filename = api.get("tokenizer.json")?;
 
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = api.get(rfilename)?;
         filenames.push(filename);
     }
@@ -203,7 +188,7 @@ fn main() -> Result<()> {
         println!("Rank {rank:?} spawned");
 
         let device = Device::new_cuda(i)?;
-        let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
+        let cache = model::Cache::new(&config, &device)?;
         println!("building the model");
 
         let handles = filenames
@@ -233,11 +218,7 @@ fn main() -> Result<()> {
         let mut index_pos = 0;
         for index in 0..args.sample_len {
             let start_gen = std::time::Instant::now();
-            let context_size = if cache.use_kv_cache && index > 0 {
-                1
-            } else {
-                tokens.len()
-            };
+            let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
             let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
             let logits = llama.forward(&input, index_pos)?;
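With flash-attn added to required-features, the example now only builds when that feature is enabled, e.g. cargo run --release --example llama_multiprocess --features cuda,nccl,flash-attn.

The main.rs changes also collapse the explicit Repo/RepoType plumbing into the repo handle returned by hf-hub's Api::model. For reference, a minimal standalone sketch of the new download pattern (the model id and shard names mirror the example; using anyhow for error handling is an assumption):

use hf_hub::api::sync::Api;
use std::path::PathBuf;

fn fetch_model_files() -> anyhow::Result<Vec<PathBuf>> {
    let api = Api::new()?;
    // Binding the repo id once means later get() calls take just a file
    // name instead of a (Repo, file) pair.
    let repo = api.model("meta-llama/Llama-2-7b-hf".to_string());
    let mut files = vec![repo.get("tokenizer.json")?];
    for shard in [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ] {
        files.push(repo.get(shard)?);
    }
    Ok(files)
}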
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 230a2f1e..573eae11 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -3,7 +3,6 @@ use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shap
 use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
-use std::collections::HashMap;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};
@@ -137,17 +136,14 @@ impl Config {
 
 #[derive(Clone)]
 pub struct Cache {
-    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
-    pub use_kv_cache: bool,
     #[allow(clippy::type_complexity)]
     kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
     cos: Tensor,
     sin: Tensor,
-    device: Device,
 }
 
 impl Cache {
-    pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
+    pub fn new(config: &Config, device: &Device) -> Result<Self> {
         // precompute freqs_cis
         let n_elem = config.n_embd / config.n_head;
         let theta: Vec<_> = (0..n_elem)
@@ -162,31 +158,14 @@ impl Cache {
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
         let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
-        let cos = idx_theta.cos()?;
-        let sin = idx_theta.sin()?;
+        let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
+        let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
         Ok(Self {
-            masks: Arc::new(Mutex::new(HashMap::new())),
-            use_kv_cache,
             kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
-            device: device.clone(),
             cos,
             sin,
         })
     }
-
-    fn mask(&self, t: usize) -> Result<Tensor> {
-        let mut masks = self.masks.lock().unwrap();
-        if let Some(mask) = masks.get(&t) {
-            Ok(mask.clone())
-        } else {
-            let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
-                .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
-            masks.insert(t, mask.clone());
-            Ok(mask)
-        }
-    }
 }
 
 fn silu(xs: &Tensor) -> Result<Tensor> {
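Since the example is now f16-only, Cache::new casts the rotary cos/sin tables to f16 once at construction, so the forward pass no longer round-trips activations through f32 (and the per-sequence-length causal mask cache disappears entirely; see the attention changes below). A self-contained sketch of the precomputation, assuming the example's MAX_SEQ_LEN constant and n_elem = n_embd / n_head:

use candle::{DType, Device, Result, Tensor, D};

const MAX_SEQ_LEN: usize = 4096; // assumption: matches the example's constant

fn rotary_tables(n_elem: usize, device: &Device) -> Result<(Tensor, Tensor)> {
    // Inverse frequencies 1 / 10000^(2i / n_elem).
    let theta: Vec<_> = (0..n_elem)
        .step_by(2)
        .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    // Outer product of positions and frequencies, duplicated along the last
    // dimension to cover the full head dimension.
    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((MAX_SEQ_LEN, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
    // Cast once here rather than casting activations on every forward pass.
    let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
    let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
    Ok((cos, sin))
}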
@@ -260,7 +239,6 @@ impl CausalSelfAttention {
     }
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
-        let x_dtype = x.dtype();
         let (b_sz, seq_len, _) = x.shape().dims3()?;
         let qkv = self.qkv_proj.forward(x)?;
@@ -282,51 +260,46 @@ impl CausalSelfAttention {
 
         let q = q
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
         let k = k
             .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
         let mut v = v
             .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
 
-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
-                let k_seq_len = k.dims()[1];
-                if k_seq_len > MAX_SEQ_LEN {
-                    k = k
-                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
-                        .contiguous()?
-                }
-                let v_seq_len = v.dims()[1];
-                if v_seq_len > 2 * MAX_SEQ_LEN {
-                    v = v
-                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
-                        .contiguous()?
-                }
+        let mut cache = self.cache.kvs.lock().unwrap();
+        if let Some((cache_k, cache_v)) = &cache[block_idx] {
+            k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+            v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
+            let k_seq_len = k.dims()[1];
+            if k_seq_len > MAX_SEQ_LEN {
+                k = k
+                    .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                    .contiguous()?
+            }
+            let v_seq_len = v.dims()[1];
+            if v_seq_len > 2 * MAX_SEQ_LEN {
+                v = v
+                    .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                    .contiguous()?
             }
-            cache[block_idx] = Some((k.clone(), v.clone()))
         }
+        cache[block_idx] = Some((k.clone(), v.clone()));
 
         let k = self.repeat_kv(k)?;
         let v = self.repeat_kv(v)?;
 
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+        let q = q.transpose(1, 2)?;
+        let k = k.transpose(1, 2)?;
+        let v = v.transpose(1, 2)?;
+        let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+        let y =
+            candle_flash_attn::flash_attn(q, k, v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = y.to_dtype(x_dtype)?;
         let y = self.o_proj.forward(&y)?;
         Ok(y)
     }
@@ -363,13 +336,6 @@ impl CausalSelfAttention {
     }
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 struct Mlp {
     c_fc1: TensorParallelColumnLinear,
     c_fc2: TensorParallelColumnLinear,
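The attention rewrite replaces the materialized causal mask, masked_fill, and explicit softmax/matmul pipeline with a single fused kernel. candle_flash_attn::flash_attn expects (batch, seq_len, num_heads, head_dim) inputs in half precision on a CUDA device, hence the extra transposes around the call, and the causal flag is only set while the prompt is being processed: with the KV cache always on, each decode step has a single query position and needs no mask. A minimal sketch of that call contract (a reference-taking flash_attn signature is assumed; shapes are read from q for illustration):

use candle::{Result, Tensor};

fn fused_attention(q: &Tensor, k: &Tensor, v: &Tensor, seq_len: usize) -> Result<Tensor> {
    // q, k, v: f16 CUDA tensors of shape (batch, seq_len, num_heads, head_dim).
    let head_dim = q.dim(3)?;
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    // Causal masking only matters when several query positions are present,
    // i.e. during prompt processing; single-token decode steps pass false.
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, seq_len > 1)
}

Besides dropping the per-sequence-length mask allocation, this removes the f32 round-trip: flash attention accumulates at higher precision inside the kernel, so the f16 output can feed o_proj directly.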