author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-26 08:28:33 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-26 08:28:33 +0100 |
commit | e40b150bbee980601f0a37ba4646216ee48bfbfb (patch) | |
tree | a4c3502fe71ed76481e76be843ccead7d9219ca1 /candle-examples/examples/llama | |
parent | 471855e2eec29ffd082dc3ea22157602baae3085 (diff) | |
Better handling of dtypes in llama. (#243)
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r-- | candle-examples/examples/llama/main.rs  |  2
-rw-r--r-- | candle-examples/examples/llama/model.rs | 23
2 files changed, 12 insertions, 13 deletions
```diff
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 62bff0ae..cd82c518 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -128,8 +128,8 @@ fn main() -> Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
     let config = Config::config_7b(args.use_flash_attn);
-    let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
     let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+    let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
     let (llama, tokenizer_filename) = match args.npy {
         Some(filename) => {
             let vb = VarBuilder::from_npz(filename, dtype, &device)?;
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f82d0ac8..f2f2fe28 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -43,7 +43,7 @@ pub struct Cache {
 }
 
 impl Cache {
-    pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
+    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
         // precompute freqs_cis
         let n_elem = config.n_embd / config.n_head;
         let theta: Vec<_> = (0..n_elem)
@@ -58,8 +58,8 @@ impl Cache {
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
         let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
-        let cos = idx_theta.cos()?;
-        let sin = idx_theta.sin()?;
+        let cos = idx_theta.cos()?.to_dtype(dtype)?;
+        let sin = idx_theta.sin()?.to_dtype(dtype)?;
         Ok(Self {
             masks: Arc::new(Mutex::new(HashMap::new())),
             use_kv_cache,
@@ -170,7 +170,6 @@ impl CausalSelfAttention {
     }
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
-        let x_dtype = x.dtype();
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
         let k = self.k_proj.forward(x)?;
@@ -178,16 +177,13 @@ impl CausalSelfAttention {
 
         let q = q
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
         let k = k
             .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
         let mut v = v
             .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
@@ -219,15 +215,18 @@ impl CausalSelfAttention {
         let y = if self.use_flash_attn {
             flash_attn(&q, &k, &v)?
         } else {
+            let in_dtype = q.dtype();
+            let q = q.to_dtype(DType::F32)?;
+            let k = k.to_dtype(DType::F32)?;
+            let v = v.to_dtype(DType::F32)?;
             let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
             let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
             let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
             let att = att.softmax(D::Minus1)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
-            att.matmul(&v.contiguous()?)?
+            att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
         };
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = y.to_dtype(x_dtype)?;
         let y = self.o_proj.forward(&y)?;
         Ok(y)
     }
```
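Taken together, the change keeps `q`/`k`/`v` in the model dtype end to end (F16 by default, F32 with `--use-f32`), upcasts to F32 only inside the non-flash attention branch where the scaled softmax runs, and casts the result back before `o_proj`; the rotary `cos`/`sin` tables are likewise cast once in `Cache::new` rather than on every forward pass. The sketch below shows that upcast/downcast pattern in isolation. It is a minimal, hypothetical example rather than code from the repository: `attention_f32_accum` is an invented helper, the causal mask and rotary embeddings are omitted, the tensor calls (`to_dtype`, `matmul`, `softmax`, `contiguous`) are used as they appear in this diff, and the crate is imported as `candle` the way these examples reference candle-core inside the workspace.

```rust
use candle::{DType, Device, Result, Tensor, D};

// Hypothetical helper illustrating the commit's dtype handling: accept tensors in
// the model dtype, accumulate the attention math in F32, return the input dtype.
fn attention_f32_accum(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize) -> Result<Tensor> {
    let in_dtype = q.dtype();
    // Upcast only for the matmul/softmax path, mirroring the non-flash branch above.
    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    let att = (q.matmul(&k.t()?)? / (head_dim as f64).sqrt())?;
    let att = att.softmax(D::Minus1)?;
    // Cast back so the caller keeps working in its original dtype (e.g. F16).
    att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // (batch, heads, seq, head_dim) activations in F16.
    let q = Tensor::zeros((1, 2, 4, 8), DType::F16, &device)?;
    let k = Tensor::zeros((1, 2, 4, 8), DType::F16, &device)?;
    let v = Tensor::zeros((1, 2, 4, 8), DType::F16, &device)?;
    let y = attention_f32_accum(&q, &k, &v, 8)?;
    assert_eq!(y.dtype(), DType::F16); // output is back in the input dtype
    Ok(())
}
```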