diff options
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/qwen2.rs | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 16ee8b01..3dce5c6a 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
 
 impl ModelForCausalLM {
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let base_model = Model::new(cfg, vb)?;
+        let base_model = Model::new(cfg, vb.clone())?;
+        let lm_head = if vb.contains_tensor("lm_head") {
+            linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        } else {
+            Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
+        };
         Ok(Self {
             base_model,
             lm_head,