summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorYin Guobing <yinguobing@gmail.com>2024-05-17 03:34:10 +0800
committerGitHub <noreply@github.com>2024-05-16 21:34:10 +0200
commit349c3e806a15399df8289c41b2e24c3fa24b6d84 (patch)
treec0e0f625c115b3e97c04ab9281122d814ad027db /candle-transformers
parentbdaa34216a2bb3527b6e248030f434561f9cf620 (diff)
downloadcandle-349c3e806a15399df8289c41b2e24c3fa24b6d84.tar.gz
candle-349c3e806a15399df8289c41b2e24c3fa24b6d84.tar.bz2
candle-349c3e806a15399df8289c41b2e24c3fa24b6d84.zip
Support embedding model gte-Qwen1.5-7B-instruct (#2190)
* Support embedding model gte-Qwen1.5-7B-instruct This is a text embedding model based on Qwen2. They share same model architecture except the last MLP module. This commit brings in minimal modification of the old Qwen2 implementation to support both models. An example is provided, and had been verified according to the official PyTorch implementation. * Avoid doing the 'last-token filtering' based on the absence of attention mask. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/qwen2.rs77
1 files changed, 62 insertions, 15 deletions
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index c9b5ae01..16ee8b01 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -1,5 +1,5 @@
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
-use candle::{DType, Device, Module, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
@@ -250,7 +250,6 @@ pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
- lm_head: Linear,
sliding_window: usize,
device: Device,
dtype: DType,
@@ -269,19 +268,17 @@ impl Model {
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
- let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
norm,
- lm_head,
sliding_window: cfg.sliding_window,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
- fn prepare_decoder_attention_mask(
+ fn prepare_causal_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
@@ -301,7 +298,7 @@ impl Model {
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
- let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
+ let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
@@ -310,21 +307,42 @@ impl Model {
.to_dtype(self.dtype)
}
- pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
+ let (b_sz, sql_len) = attn_mask.dims2()?;
+ let mut mask: Vec<Tensor> = vec![];
+ for b in 0..b_sz {
+ mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
+ }
+ let mask = Tensor::cat(&mask, 0)?;
+ let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
+ let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
+ .broadcast_as(mask.shape())?
+ .to_dtype(self.dtype)?;
+ mask.where_cond(&on_true, &on_false)
+ }
+
+ pub fn forward(
+ &mut self,
+ input_ids: &Tensor,
+ seqlen_offset: usize,
+ attn_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
- let attention_mask = if seq_len <= 1 {
- None
- } else {
- let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
- Some(mask)
+ let attention_mask: Option<Tensor> = match attn_mask {
+ Some(mask) => Some(self.prepare_attention_mask(mask)?),
+ None => {
+ if seq_len <= 1 {
+ None
+ } else {
+ Some(self.prepare_causal_attention_mask(b_size, seq_len, seqlen_offset)?)
+ }
+ }
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
- xs.narrow(1, seq_len - 1, 1)?
- .apply(&self.norm)?
- .apply(&self.lm_head)
+ xs.apply(&self.norm)
}
pub fn clear_kv_cache(&mut self) {
@@ -333,3 +351,32 @@ impl Model {
}
}
}
+
+#[derive(Debug, Clone)]
+pub struct ModelForCausalLM {
+ base_model: Model,
+ lm_head: Linear,
+}
+
+impl ModelForCausalLM {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+ let base_model = Model::new(cfg, vb)?;
+ Ok(Self {
+ base_model,
+ lm_head,
+ })
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let (_b_size, seq_len) = input_ids.dims2()?;
+ self.base_model
+ .forward(input_ids, seqlen_offset, None)?
+ .narrow(1, seq_len - 1, 1)?
+ .apply(&self.lm_head)
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.base_model.clear_kv_cache()
+ }
+}