Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/qwen2.rs  |  8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 16ee8b01..3dce5c6a 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
 impl ModelForCausalLM {
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let base_model = Model::new(cfg, vb)?;
+        let base_model = Model::new(cfg, vb.clone())?;
+        let lm_head = if vb.contains_tensor("lm_head") {
+            linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        } else {
+            Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
+        };
         Ok(Self {
             base_model,
             lm_head,
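
The change above falls back to weight tying when the checkpoint does not ship a separate `lm_head` tensor: the token-embedding matrix, of shape `(vocab_size, hidden_size)`, is reused as the output projection. A minimal sketch of that pattern with plain `candle_nn::Linear` (the `tied_lm_head` helper and the tensor shapes are illustrative, not part of this commit):

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

// Reuse the embedding matrix as the output projection. `Linear` computes
// `x @ w^T` (no bias here), so an embedding matrix of shape
// (vocab_size, hidden_size) maps hidden states back onto the vocabulary.
fn tied_lm_head(embed_weights: &Tensor) -> Linear {
    Linear::new(embed_weights.clone(), None)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (vocab_size, hidden_size) = (32, 8);
    // Stand-in for the model's token embeddings.
    let embed = Tensor::randn(0f32, 1.0, (vocab_size, hidden_size), &dev)?;
    let lm_head = tied_lm_head(&embed);

    // Hidden states of shape (batch, seq, hidden) project to (batch, seq, vocab).
    let hidden = Tensor::randn(0f32, 1.0, (1, 4, hidden_size), &dev)?;
    let logits = lm_head.forward(&hidden)?;
    assert_eq!(logits.dims(), &[1, 4, vocab_size]);
    Ok(())
}
```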