author     Laurent Mazare <laurent.mazare@gmail.com>    2023-07-26 08:28:33 +0100
committer  GitHub <noreply@github.com>                  2023-07-26 08:28:33 +0100
commit     e40b150bbee980601f0a37ba4646216ee48bfbfb (patch)
tree       a4c3502fe71ed76481e76be843ccead7d9219ca1 /candle-examples/examples/llama
parent     471855e2eec29ffd082dc3ea22157602baae3085 (diff)
Better handling of dtypes in llama. (#243)
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--  candle-examples/examples/llama/main.rs   |  2
-rw-r--r--  candle-examples/examples/llama/model.rs  | 23
2 files changed, 12 insertions, 13 deletions
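
In brief, this change threads the runtime dtype (F16 by default, F32 with the use_f32 flag) into the rotary-embedding cache instead of always building the cos/sin tensors in F32, and confines the F32 upcast to the softmax path of the attention. A minimal sketch of the new call-site pattern in main.rs, reusing the names that appear in the diff below rather than quoting it verbatim:

    // Choose the model dtype from the CLI flag, then pass it to the cache so the
    // precomputed cos/sin tensors match the dtype of the weights.
    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
    let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
    // Weights are loaded in the same dtype (here from an .npz checkpoint).
    let vb = VarBuilder::from_npz(filename, dtype, &device)?;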
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 62bff0ae..cd82c518 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -128,8 +128,8 @@ fn main() -> Result<()> {
let device = candle_examples::device(args.cpu)?;
let config = Config::config_7b(args.use_flash_attn);
- let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+ let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let (llama, tokenizer_filename) = match args.npy {
Some(filename) => {
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f82d0ac8..f2f2fe28 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -43,7 +43,7 @@ pub struct Cache {
}
impl Cache {
- pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
+ pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
@@ -58,8 +58,8 @@ impl Cache {
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
- let cos = idx_theta.cos()?;
- let sin = idx_theta.sin()?;
+ let cos = idx_theta.cos()?.to_dtype(dtype)?;
+ let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
@@ -170,7 +170,6 @@ impl CausalSelfAttention {
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
- let x_dtype = x.dtype();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
@@ -178,16 +177,13 @@ impl CausalSelfAttention {
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
- .transpose(1, 2)?
- .to_dtype(DType::F32)?;
+ .transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
- .transpose(1, 2)?
- .to_dtype(DType::F32)?;
+ .transpose(1, 2)?;
let mut v = v
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
- .transpose(1, 2)?
- .to_dtype(DType::F32)?;
+ .transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
@@ -219,15 +215,18 @@ impl CausalSelfAttention {
let y = if self.use_flash_attn {
flash_attn(&q, &k, &v)?
} else {
+ let in_dtype = q.dtype();
+ let q = q.to_dtype(DType::F32)?;
+ let k = k.to_dtype(DType::F32)?;
+ let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
- att.matmul(&v.contiguous()?)?
+ att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
- let y = y.to_dtype(x_dtype)?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
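
For reference, after this change the non-flash-attention branch upcasts to F32 only around the score computation and softmax (for numerical stability with F16 models) and converts the result back to the caller's dtype before the output projection. A condensed sketch of that pattern, using only the calls that appear in the diff above:

    // Upcast q/k/v to F32 for the attention scores and softmax, then return the
    // result to the original dtype (e.g. F16) before the output projection.
    let in_dtype = q.dtype();
    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
    let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
    let att = att.softmax(D::Minus1)?;
    let y = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;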