From ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc Mon Sep 17 00:00:00 2001
From: Laurent Mazare <laurent.mazare@gmail.com>
Date: Thu, 26 Sep 2024 21:00:18 +0200
Subject: Add some llama-3.2 examples. (#2508)

* Add some llama-3.2 examples.

* Support tie-word-embeddings for llama.
---
 candle-transformers/src/models/llama.rs        | 11 ++++++++++-
 candle-transformers/src/models/llava/config.rs |  3 +++
 2 files changed, 13 insertions(+), 1 deletion(-)

(limited to 'candle-transformers')

diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index e96bb855..a7bef099 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -44,6 +44,7 @@ pub struct LlamaConfig {
     pub eos_token_id: Option<LlamaEosToks>,
     pub rope_scaling: Option<Llama3RopeConfig>,
     pub max_position_embeddings: usize,
+    pub tie_word_embeddings: Option<bool>,
 }
 
 impl LlamaConfig {
@@ -72,6 +73,7 @@ impl LlamaConfig {
             eos_token_id: self.eos_token_id,
             rope_scaling: self.rope_scaling,
             max_position_embeddings: self.max_position_embeddings,
+            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
         }
     }
 }
@@ -91,6 +93,7 @@ pub struct Config {
     pub eos_token_id: Option<LlamaEosToks>,
     pub rope_scaling: Option<Llama3RopeConfig>,
     pub max_position_embeddings: usize,
+    pub tie_word_embeddings: bool,
 }
 
 impl Config {
@@ -109,6 +112,7 @@ impl Config {
             eos_token_id: None,
             rope_scaling: None,
             max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
+            tie_word_embeddings: false,
         }
     }
 
@@ -127,6 +131,7 @@ impl Config {
             eos_token_id: None,
             rope_scaling: None,
             max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
+            tie_word_embeddings: false,
         }
     }
 }
@@ -504,7 +509,11 @@ impl Llama {
 
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let lm_head = if cfg.tie_word_embeddings {
+            Linear::from_weights(wte.embeddings().clone(), None)
+        } else {
+            linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        };
         let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
             .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs
index 5dca6870..405eedb9 100644
--- a/candle-transformers/src/models/llava/config.rs
+++ b/candle-transformers/src/models/llava/config.rs
@@ -43,6 +43,7 @@ pub struct LLaVAConfig {
     pub image_token_index: isize,
     #[serde(default = "default_hf")]
     pub hf: bool,
+    pub tie_word_embeddings: Option<bool>,
 }
 
 fn default_hf() -> bool {
@@ -77,6 +78,7 @@ impl LLaVAConfig {
             use_flash_attn: false,
             rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
             max_position_embeddings: self.max_position_embeddings,
+            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
         }
     }
 }
@@ -264,6 +266,7 @@ impl HFLLaVAConfig {
             use_cache: self.text_config.use_cache,
             vocab_size: self.vocab_size,
             image_token_index: self.image_token_index,
+            tie_word_embeddings: None,
         }
     }
 }
-- 
cgit v1.2.3