diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-27 20:32:30 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-27 20:32:30 +0100 |
commit | 205767f9ded3d531822d3702442a52b4a320f72e (patch) | |
tree | 1d03f2d99e63992a99ed42ae6fd50bdfe4f8dd0a | |
parent | 5e526abc8c0ecad2bd110a34d128e1e6d5333c68 (diff) | |
download | candle-205767f9ded3d531822d3702442a52b4a320f72e.tar.gz candle-205767f9ded3d531822d3702442a52b4a320f72e.tar.bz2 candle-205767f9ded3d531822d3702442a52b4a320f72e.zip |
Avoid tensor copying in the quantized example. (#1770)
-rw-r--r-- | candle-transformers/src/models/quantized_llama.rs | 12 |
1 file changed, 8 insertions, 4 deletions
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index eb4136f6..94324149 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -157,16 +157,16 @@ struct LayerWeights { head_dim: usize, cos: Tensor, sin: Tensor, + neg_inf: Tensor, kv_cache: Option<(Tensor, Tensor)>, span_attn: tracing::Span, span_rot: tracing::Span, span_mlp: tracing::Span, } -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> { let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; + let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; Ok(m) } @@ -240,7 +240,7 @@ impl LayerWeights { let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let mask = mask.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = masked_fill(&att, &mask, &self.neg_inf)?; let att = candle_nn::ops::softmax_last_dim(&att)?; // Convert to contiguous as matmul doesn't support strided vs for now. 
let y = att.matmul(&v.contiguous()?)?; @@ -298,6 +298,7 @@ impl ModelWeights { pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?; let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; @@ -337,6 +338,7 @@ impl ModelWeights { head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, cos: cos.clone(), sin: sin.clone(), + neg_inf: neg_inf.clone(), kv_cache: None, span_attn, span_rot, @@ -385,6 +387,7 @@ impl ModelWeights { .and_then(|m| m.to_f32()) .unwrap_or(10000f32); let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; @@ -455,6 +458,7 @@ impl ModelWeights { head_dim: embedding_length / head_count, cos: cos.clone(), sin: sin.clone(), + neg_inf: neg_inf.clone(), kv_cache: None, span_attn, span_rot, |