author     Laurent Mazare <laurent.mazare@gmail.com>   2024-03-30 13:22:00 +0100
committer  GitHub <noreply@github.com>                 2024-03-30 13:22:00 +0100
commit     b190fd85920dfeb93c091593d42fda596c3a83a7 (patch)
tree       52003bc7540277e37f006198e77a145fbf8dccd4
parent     efe4a0c84b55b60f7555a89ea7e0ba8d300104cd (diff)
Remove some unnecessary calls to contiguous. (#1968)
* Remove some unnecessary calls to contiguous.
* Slightly improved kv cache concatenation.
-rw-r--r--  candle-core/src/tensor_cat.rs                      22
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs  14
2 files changed, 20 insertions, 16 deletions
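
The kv-cache change relies on a property of `Tensor::cat` that the tensor_cat.rs hunk below makes explicit: when every input is contiguous, the all-contiguous fast path is taken and the result is itself laid out contiguously, so a trailing `.contiguous()?` on the result was redundant. A minimal sketch of that property, not part of the commit, assuming candle-core's public `Tensor` API (`Tensor::zeros`, `Tensor::cat`, `is_contiguous`) with illustrative shapes and names:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Shapes mimic the (batch, n_kv_head, seq_len, head_dim) layout of the KV cache.
        let k_cache = Tensor::zeros((1, 4, 16, 8), DType::F32, &dev)?;
        let k_step = Tensor::zeros((1, 4, 1, 8), DType::F32, &dev)?;
        // Both inputs are contiguous, so cat takes the contiguous fast path and
        // returns a freshly allocated, contiguous tensor; an extra `.contiguous()?`
        // on the result would be a no-op.
        let k = Tensor::cat(&[&k_cache, &k_step], 2)?;
        assert!(k.is_contiguous());
        Ok(())
    }
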
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
index 31cc8503..27ff7851 100644
--- a/candle-core/src/tensor_cat.rs
+++ b/candle-core/src/tensor_cat.rs
@@ -58,20 +58,18 @@ impl Tensor {
                 }
             }
         }
-        if dim == 0 {
+        let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
+        if all_contiguous {
+            Self::cat_contiguous(args, dim)
+        } else if dim == 0 {
             Self::cat0(args)
         } else {
-            let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
-            if all_contiguous {
-                Self::cat_contiguous(args, dim)
-            } else {
-                let args: Vec<Tensor> = args
-                    .iter()
-                    .map(|a| a.as_ref().transpose(0, dim))
-                    .collect::<Result<Vec<_>>>()?;
-                let cat = Self::cat0(&args)?;
-                cat.transpose(0, dim)
-            }
+            let args: Vec<Tensor> = args
+                .iter()
+                .map(|a| a.as_ref().transpose(0, dim))
+                .collect::<Result<Vec<_>>>()?;
+            let cat = Self::cat0(&args)?;
+            cat.transpose(0, dim)
         }
     }
 
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 9898d872..e1519b2d 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -157,6 +157,8 @@ impl LayerWeights {
         let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
         let cos = self.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.sin.narrow(0, index_pos, seq_len)?;
+        // The call to contiguous below is only necessary when processing the prompt.
+        // When the seq_len is 1 in the inference loop, this is a no-op.
         candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin)
     }
 
@@ -180,7 +182,11 @@ impl LayerWeights {
             .transpose(1, 2)?;
         let v = v
             .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            // This call to contiguous ensures that the fast kernel can be called below. It's
+            // actually a no-op except when processing the initial prompt so has no significant
+            // impact on performance.
+            .contiguous()?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let k = self.apply_rotary_emb(&k, index_pos)?;
@@ -191,8 +197,8 @@ impl LayerWeights {
                 if index_pos == 0 {
                     (k, v)
                 } else {
-                    let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
-                    let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+                    let k = Tensor::cat(&[k_cache, &k], 2)?;
+                    let v = Tensor::cat(&[v_cache, &v], 2)?;
                     (k, v)
                 }
             }
@@ -486,7 +492,7 @@ impl ModelWeights {
             layer_in = x
         }
         let x = self.norm.forward(&layer_in)?;
-        let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
+        let x = x.i((.., seq_len - 1, ..))?;
         let _enter = self.span_output.enter();
         self.output.forward(&x)
     }
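
The comment added in `apply_rotary_emb` points at a layout detail that is easy to miss: transposing dims 1 and 2 of a `(b_sz, seq_len, n_head, head_dim)` tensor only breaks contiguity when `seq_len > 1`, i.e. while processing the prompt; in the decoding loop where `seq_len == 1` the transposed view is still contiguous, so the `.contiguous()?` call costs nothing. A small sketch of that behaviour, again an illustration rather than part of the commit, using candle-core's `Tensor` API with made-up shapes:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b_sz, n_head, head_dim) = (1, 4, 8);
        // Prompt processing: seq_len > 1, the transpose yields a non-contiguous
        // view, so the `.contiguous()?` in apply_rotary_emb has to copy.
        let q_prompt = Tensor::zeros((b_sz, 7, n_head, head_dim), DType::F32, &dev)?
            .transpose(1, 2)?;
        assert!(!q_prompt.is_contiguous());
        // Decoding loop: seq_len == 1, the transposed layout is still contiguous,
        // so the same call is a no-op.
        let q_step = Tensor::zeros((b_sz, 1, n_head, head_dim), DType::F32, &dev)?
            .transpose(1, 2)?;
        assert!(q_step.is_contiguous());
        Ok(())
    }
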