summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-06-07 10:51:50 +0100
committerGitHub <noreply@github.com>2024-06-07 10:51:50 +0100
commit54ff971e35a0fd28da062d416ffb7bc9ac9d40d8 (patch)
treec5ee47770b4f1195bc66e0bdbe75a630a4ccadbc /candle-transformers
parentb9fac7ec008bfccf8900552f51e6d0e865280ee9 (diff)
downloadcandle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.tar.gz
candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.tar.bz2
candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.zip
Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models. * Add more models.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/qwen2.rs8
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 16ee8b01..3dce5c6a 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
- let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
- let base_model = Model::new(cfg, vb)?;
+ let base_model = Model::new(cfg, vb.clone())?;
+ let lm_head = if vb.contains_tensor("lm_head") {
+ linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+ } else {
+ Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
+ };
Ok(Self {
base_model,
lm_head,