diff options
Diffstat (limited to 'candle-transformers/src/models/llama.rs')
-rw-r--r-- | candle-transformers/src/models/llama.rs | 11 |
1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index e96bb855..a7bef099 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -44,6 +44,7 @@ pub struct LlamaConfig { pub eos_token_id: Option<LlamaEosToks>, pub rope_scaling: Option<Llama3RopeConfig>, pub max_position_embeddings: usize, + pub tie_word_embeddings: Option<bool>, } impl LlamaConfig { @@ -72,6 +73,7 @@ impl LlamaConfig { eos_token_id: self.eos_token_id, rope_scaling: self.rope_scaling, max_position_embeddings: self.max_position_embeddings, + tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false), } } } @@ -91,6 +93,7 @@ pub struct Config { pub eos_token_id: Option<LlamaEosToks>, pub rope_scaling: Option<Llama3RopeConfig>, pub max_position_embeddings: usize, + pub tie_word_embeddings: bool, } impl Config { @@ -109,6 +112,7 @@ impl Config { eos_token_id: None, rope_scaling: None, max_position_embeddings: DEFAULT_MAX_SEQ_LEN, + tie_word_embeddings: false, } } @@ -127,6 +131,7 @@ impl Config { eos_token_id: None, rope_scaling: None, max_position_embeddings: DEFAULT_MAX_SEQ_LEN, + tie_word_embeddings: false, } } } @@ -504,7 +509,11 @@ impl Llama { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = if cfg.tie_word_embeddings { + Linear::from_weights(wte.embeddings().clone(), None) + } else { + linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + }; let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.num_hidden_layers) .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap()) |