author     Laurent Mazare <laurent.mazare@gmail.com>  2024-03-30 13:22:00 +0100
committer  GitHub <noreply@github.com>                2024-03-30 13:22:00 +0100
commit     b190fd85920dfeb93c091593d42fda596c3a83a7 (patch)
tree       52003bc7540277e37f006198e77a145fbf8dccd4
parent     efe4a0c84b55b60f7555a89ea7e0ba8d300104cd (diff)
Remove some unnecessary calls to contiguous. (#1968)
* Remove some unnecessary calls to contiguous.
* Slightly improved kv cache concatenation.
-rw-r--r--  candle-core/src/tensor_cat.rs                      22
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs  14
2 files changed, 20 insertions, 16 deletions
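To make the reasoning behind the removed calls concrete, here is a small standalone sketch (not part of the commit; the shapes b_sz = 1, n_head = 32, head_dim = 128 are purely illustrative) of when a transpose of the q/k/v tensors stays contiguous: in the per-token decode loop seq_len is 1, so swapping the head and sequence dims does not change the memory layout and a .contiguous() call is a no-op, while during prompt processing it actually has to copy.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Decode step: seq_len == 1, so swapping dims 1 and 2 leaves the memory
    // order unchanged and the tensor is still reported as contiguous.
    let x = Tensor::zeros((1, 1, 32, 128), DType::F32, &dev)?;
    assert!(x.transpose(1, 2)?.is_contiguous());
    // Prompt processing: seq_len > 1, the transpose yields a strided layout
    // and a subsequent .contiguous() has to materialize a copy.
    let x = Tensor::zeros((1, 7, 32, 128), DType::F32, &dev)?;
    assert!(!x.transpose(1, 2)?.is_contiguous());
    Ok(())
}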
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
index 31cc8503..27ff7851 100644
--- a/candle-core/src/tensor_cat.rs
+++ b/candle-core/src/tensor_cat.rs
@@ -58,20 +58,18 @@ impl Tensor {
}
}
}
- if dim == 0 {
+ let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
+ if all_contiguous {
+ Self::cat_contiguous(args, dim)
+ } else if dim == 0 {
Self::cat0(args)
} else {
- let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
- if all_contiguous {
- Self::cat_contiguous(args, dim)
- } else {
- let args: Vec<Tensor> = args
- .iter()
- .map(|a| a.as_ref().transpose(0, dim))
- .collect::<Result<Vec<_>>>()?;
- let cat = Self::cat0(&args)?;
- cat.transpose(0, dim)
- }
+ let args: Vec<Tensor> = args
+ .iter()
+ .map(|a| a.as_ref().transpose(0, dim))
+ .collect::<Result<Vec<_>>>()?;
+ let cat = Self::cat0(&args)?;
+ cat.transpose(0, dim)
}
}
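As a usage-level illustration of the tensor_cat.rs change (a minimal sketch, not from the repository; the shapes and dtype are arbitrary): once every input is contiguous, Tensor::cat dispatches to the contiguous fast path for any dim, and the result comes back contiguous, which is what allows the trailing .contiguous() calls in the kv cache code below to be dropped.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Two contiguous tensors that only differ along dim 2 (the sequence dim).
    let k_cache = Tensor::zeros((1, 8, 16, 64), DType::F32, &dev)?;
    let k_new = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
    // With this change, cat on all-contiguous inputs uses cat_contiguous even
    // for dim != 0, so the concatenated tensor is already contiguous.
    let k = Tensor::cat(&[&k_cache, &k_new], 2)?;
    assert_eq!(k.dims(), &[1, 8, 17, 64]);
    assert!(k.is_contiguous());
    Ok(())
}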
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 9898d872..e1519b2d 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -157,6 +157,8 @@ impl LayerWeights {
let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
+ // The call to contiguous below is only necessary when processing the prompt.
+ // When the seq_len is 1 in the inference loop, this is a no-op.
candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin)
}
@@ -180,7 +182,11 @@ impl LayerWeights {
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
- .transpose(1, 2)?;
+ .transpose(1, 2)?
+ // This call to contiguous ensures that the fast kernel can be called below. It's
+ // actually a no-op except when processing the initial prompt, so it has no
+ // significant impact on performance.
+ .contiguous()?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let k = self.apply_rotary_emb(&k, index_pos)?;
@@ -191,8 +197,8 @@ impl LayerWeights {
if index_pos == 0 {
(k, v)
} else {
- let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
- let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+ let k = Tensor::cat(&[k_cache, &k], 2)?;
+ let v = Tensor::cat(&[v_cache, &v], 2)?;
(k, v)
}
}
@@ -486,7 +492,7 @@ impl ModelWeights {
layer_in = x
}
let x = self.norm.forward(&layer_in)?;
- let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
+ let x = x.i((.., seq_len - 1, ..))?;
let _enter = self.span_output.enter();
self.output.forward(&x)
}
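Putting the kv cache part of the change together, here is a hypothetical, simplified cache-append helper (names such as KvCache and append are illustrative, not the actual quantized_llama code) that follows the same pattern as the hunk above: the first call stores k/v as-is, later calls concatenate along the sequence dim, and because both operands are contiguous the cat fast path applies and no trailing .contiguous() is needed.

use candle_core::{Result, Tensor};

// Hypothetical cache holding k/v with shape (b_sz, n_kv_head, seq_len, head_dim).
struct KvCache {
    k: Option<Tensor>,
    v: Option<Tensor>,
}

impl KvCache {
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        let (k, v) = match (&self.k, &self.v) {
            // First step (index_pos == 0): the prompt's k/v become the cache.
            (None, None) => (k.clone(), v.clone()),
            // Later steps: grow the cache along the sequence dim (2). Both
            // operands are contiguous, so cat takes the fast path and the
            // result needs no extra .contiguous() call.
            (Some(k_cache), Some(v_cache)) => (
                Tensor::cat(&[k_cache, k], 2)?,
                Tensor::cat(&[v_cache, v], 2)?,
            ),
            _ => unreachable!("k and v are always cached together"),
        };
        self.k = Some(k.clone());
        self.v = Some(v.clone());
        Ok((k, v))
    }
}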