Diffstat (limited to 'candle-examples/examples/llama_multiprocess/model.rs')
-rw-r--r--  candle-examples/examples/llama_multiprocess/model.rs  90
1 file changed, 28 insertions(+), 62 deletions(-)
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 230a2f1e..573eae11 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -3,7 +3,6 @@ use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shap
use candle_nn::{Embedding, Linear, VarBuilder};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
-use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
@@ -137,17 +136,14 @@ impl Config {
#[derive(Clone)]
pub struct Cache {
- masks: Arc<Mutex<HashMap<usize, Tensor>>>,
- pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
- device: Device,
}
impl Cache {
- pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
+ pub fn new(config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
@@ -162,31 +158,14 @@ impl Cache {
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
- let cos = idx_theta.cos()?;
- let sin = idx_theta.sin()?;
+ let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
+ let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
Ok(Self {
- masks: Arc::new(Mutex::new(HashMap::new())),
- use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
- device: device.clone(),
cos,
sin,
})
}
-
- fn mask(&self, t: usize) -> Result<Tensor> {
- let mut masks = self.masks.lock().unwrap();
- if let Some(mask) = masks.get(&t) {
- Ok(mask.clone())
- } else {
- let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
- .collect();
- let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
- masks.insert(t, mask.clone());
- Ok(mask)
- }
- }
}
fn silu(xs: &Tensor) -> Result<Tensor> {
@@ -260,7 +239,6 @@ impl CausalSelfAttention {
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
- let x_dtype = x.dtype();
let (b_sz, seq_len, _) = x.shape().dims3()?;
let qkv = self.qkv_proj.forward(x)?;
@@ -282,51 +260,46 @@ impl CausalSelfAttention {
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
- .transpose(1, 2)?
- .to_dtype(DType::F32)?;
+ .transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
- .transpose(1, 2)?
- .to_dtype(DType::F32)?;
+ .transpose(1, 2)?;
let mut v = v
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
- .transpose(1, 2)?
- .to_dtype(DType::F32)?;
+ .transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
- if self.cache.use_kv_cache {
- let mut cache = self.cache.kvs.lock().unwrap();
- if let Some((cache_k, cache_v)) = &cache[block_idx] {
- k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
- v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
- let k_seq_len = k.dims()[1];
- if k_seq_len > MAX_SEQ_LEN {
- k = k
- .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
- .contiguous()?
- }
- let v_seq_len = v.dims()[1];
- if v_seq_len > 2 * MAX_SEQ_LEN {
- v = v
- .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
- .contiguous()?
- }
+ let mut cache = self.cache.kvs.lock().unwrap();
+ if let Some((cache_k, cache_v)) = &cache[block_idx] {
+ k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+ v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
+ let k_seq_len = k.dims()[1];
+ if k_seq_len > MAX_SEQ_LEN {
+ k = k
+ .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+ .contiguous()?
+ }
+ let v_seq_len = v.dims()[1];
+ if v_seq_len > 2 * MAX_SEQ_LEN {
+ v = v
+ .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+ .contiguous()?
}
- cache[block_idx] = Some((k.clone(), v.clone()))
}
+ cache[block_idx] = Some((k.clone(), v.clone()));
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
- let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
- let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
- let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+ let q = q.transpose(1, 2)?;
+ let k = k.transpose(1, 2)?;
+ let v = v.transpose(1, 2)?;
+ let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+ let y =
+ candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
- let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
- let y = y.to_dtype(x_dtype)?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
@@ -363,13 +336,6 @@ impl CausalSelfAttention {
}
}
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
- let shape = mask.shape();
- let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
- let m = mask.where_cond(&on_true, on_false)?;
- Ok(m)
-}
-
struct Mlp {
c_fc1: TensorParallelColumnLinear,
c_fc2: TensorParallelColumnLinear,