diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-06-07 10:51:50 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-07 10:51:50 +0100 |
commit | 54ff971e35a0fd28da062d416ffb7bc9ac9d40d8 (patch) | |
tree | c5ee47770b4f1195bc66e0bdbe75a630a4ccadbc /candle-transformers | |
parent | b9fac7ec008bfccf8900552f51e6d0e865280ee9 (diff) | |
download | candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.tar.gz candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.tar.bz2 candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.zip |
Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.
* Add more models.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/qwen2.rs | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 16ee8b01..3dce5c6a 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -360,8 +360,12 @@ pub struct ModelForCausalLM { impl ModelForCausalLM { pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { - let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; - let base_model = Model::new(cfg, vb)?; + let base_model = Model::new(cfg, vb.clone())?; + let lm_head = if vb.contains_tensor("lm_head") { + linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + } else { + Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None) + }; Ok(Self { base_model, lm_head, |