author     Laurent Mazare <laurent.mazare@gmail.com>  2024-02-27 20:32:30 +0100
committer  GitHub <noreply@github.com>                2024-02-27 20:32:30 +0100
commit     205767f9ded3d531822d3702442a52b4a320f72e
tree       1d03f2d99e63992a99ed42ae6fd50bdfe4f8dd0a
parent     5e526abc8c0ecad2bd110a34d128e1e6d5333c68
Avoid tensor copying in the quantized example. (#1770)
 candle-transformers/src/models/quantized_llama.rs | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)
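The change hoists the f32::NEG_INFINITY fill value out of the attention hot path: masked_fill previously uploaded the scalar to the device as a fresh tensor on every call, while the patched code builds a single neg_inf tensor once at model-load time and threads it through each LayerWeights, as the diff below shows. A minimal sketch of the before/after pattern, assuming the candle_core crate (standalone illustration, not code from the patch):

    use candle_core::{Device, Result, Tensor};

    // Before: one host-to-device scalar upload per attention call.
    fn fill_value_per_call(device: &Device) -> Result<Tensor> {
        Tensor::new(f32::NEG_INFINITY, device)
    }

    // After: upload once at load time, then reuse the handle everywhere.
    struct Cached {
        neg_inf: Tensor,
    }

    impl Cached {
        fn load(device: &Device) -> Result<Self> {
            let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
            Ok(Self { neg_inf })
        }
    }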
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index eb4136f6..94324149 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -157,16 +157,16 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
+    neg_inf: Tensor,
     kv_cache: Option<(Tensor, Tensor)>,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
     span_mlp: tracing::Span,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
+    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
     Ok(m)
 }
 
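With the new signature the caller owns the fill-value tensor, and masked_fill only broadcasts it to the mask's shape before selecting. A runnable sketch of the reworked helper with a hand-built causal mask (assumes candle_core; the mask construction is illustrative, not from the patch):

    use candle_core::{DType, Device, Result, Tensor};

    fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
        let shape = mask.shape();
        // Where mask is nonzero, take the broadcast fill value; elsewhere keep the input.
        let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
        Ok(m)
    }

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Created once and reused across calls: no per-call upload.
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &device)?;
        let scores = Tensor::zeros((1, 4, 4), DType::F32, &device)?;
        // Causal mask: 1 above the diagonal marks the positions to hide.
        let mask: Vec<u8> = (0..4)
            .flat_map(|i| (0..4).map(move |j| u8::from(j > i)))
            .collect();
        let mask = Tensor::from_vec(mask, (4, 4), &device)?.broadcast_as(scores.shape())?;
        let masked = masked_fill(&scores, &mask, &neg_inf)?;
        println!("{masked}");
        Ok(())
    }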
@@ -240,7 +240,7 @@ impl LayerWeights {
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = mask.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = masked_fill(&att, &mask, &self.neg_inf)?;
         let att = candle_nn::ops::softmax_last_dim(&att)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
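The numerical behavior is unchanged: positions filled with negative infinity still receive exactly zero weight from the softmax, since exp(-inf) = 0. A quick check (assumes candle_core and candle_nn):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let att = Tensor::new(&[[0.5f32, f32::NEG_INFINITY]], &Device::Cpu)?;
        let weights = candle_nn::ops::softmax_last_dim(&att)?;
        // Prints ~[[1.0, 0.0]]: the masked column contributes nothing downstream.
        println!("{weights}");
        Ok(())
    }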
@@ -298,6 +298,7 @@ impl ModelWeights {
     pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
         let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
+        let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
@@ -337,6 +338,7 @@ impl ModelWeights {
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
+                neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
                 span_rot,
@@ -385,6 +387,7 @@ impl ModelWeights {
             .and_then(|m| m.to_f32())
             .unwrap_or(10000f32);
         let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
+        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
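Both loaders build the constant on the same device as the weights (&ct.device for GGML, device for GGUF). That matters because where_cond needs all of its operands on one device, so the cached tensor must follow the model. A tiny sketch (assumes candle_core; the device selection is illustrative):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // Fall back to CPU when no GPU is available.
        let device = Device::cuda_if_available(0)?;
        // Create the cached constant where the attention scores will live.
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &device)?;
        println!("{:?}", neg_inf.device());
        Ok(())
    }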
@@ -455,6 +458,7 @@ impl ModelWeights {
                 head_dim: embedding_length / head_count,
                 cos: cos.clone(),
                 sin: sin.clone(),
+                neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
                 span_rot,
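Each layer then stores neg_inf.clone(). Cloning a candle Tensor is shallow (the struct wraps reference-counted storage), so every layer shares the one device allocation rather than copying the value again. Sketch (assumes candle_core; the layer count is illustrative):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &Device::Cpu)?;
        // Shallow handles: each clone points at the same underlying buffer.
        let per_layer: Vec<Tensor> = (0..32).map(|_| neg_inf.clone()).collect();
        assert_eq!(per_layer.len(), 32);
        Ok(())
    }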