author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-26 21:00:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-26 21:00:18 +0200 |
commit | ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc (patch) | |
tree | d6514faca57dd0204170e04d6d6a94ca295fe278 /candle-transformers | |
parent | c3c392f45c14f60eb4fb8397cc5c1d3891c9656d (diff) | |
download | candle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.tar.gz candle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.tar.bz2 candle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.zip |
Add some llama-3.2 examples. (#2508)
* Add some llama-3.2 examples.
* Support tie-word-embeddings for llama.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/llama.rs | 11 |
-rw-r--r-- | candle-transformers/src/models/llava/config.rs | 3 |
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index e96bb855..a7bef099 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -44,6 +44,7 @@ pub struct LlamaConfig {
     pub eos_token_id: Option<LlamaEosToks>,
     pub rope_scaling: Option<Llama3RopeConfig>,
     pub max_position_embeddings: usize,
+    pub tie_word_embeddings: Option<bool>,
 }
 
 impl LlamaConfig {
@@ -72,6 +73,7 @@ impl LlamaConfig {
             eos_token_id: self.eos_token_id,
             rope_scaling: self.rope_scaling,
             max_position_embeddings: self.max_position_embeddings,
+            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
         }
     }
 }
@@ -91,6 +93,7 @@ pub struct Config {
     pub eos_token_id: Option<LlamaEosToks>,
     pub rope_scaling: Option<Llama3RopeConfig>,
     pub max_position_embeddings: usize,
+    pub tie_word_embeddings: bool,
 }
 
 impl Config {
@@ -109,6 +112,7 @@ impl Config {
             eos_token_id: None,
             rope_scaling: None,
             max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
+            tie_word_embeddings: false,
         }
     }
 
@@ -127,6 +131,7 @@ impl Config {
             eos_token_id: None,
             rope_scaling: None,
             max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
+            tie_word_embeddings: false,
         }
     }
 }
@@ -504,7 +509,11 @@ impl Llama {
 
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let lm_head = if cfg.tie_word_embeddings {
+            Linear::from_weights(wte.embeddings().clone(), None)
+        } else {
+            linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        };
         let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
             .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs
index 5dca6870..405eedb9 100644
--- a/candle-transformers/src/models/llava/config.rs
+++ b/candle-transformers/src/models/llava/config.rs
@@ -43,6 +43,7 @@ pub struct LLaVAConfig {
     pub image_token_index: isize,
     #[serde(default = "default_hf")]
     pub hf: bool,
+    pub tie_word_embeddings: Option<bool>,
 }
 
 fn default_hf() -> bool {
@@ -77,6 +78,7 @@ impl LLaVAConfig {
             use_flash_attn: false,
             rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
             max_position_embeddings: self.max_position_embeddings,
+            tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
         }
     }
 }
@@ -264,6 +266,7 @@ impl HFLLaVAConfig {
             use_cache: self.text_config.use_cache,
             vocab_size: self.vocab_size,
             image_token_index: self.image_token_index,
+            tie_word_embeddings: None,
         }
     }
 }
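For context, "tie word embeddings" means the output projection (`lm_head`) reuses the token-embedding matrix instead of loading a separate `lm_head.weight` tensor, which is how the smaller Llama 3.2 checkpoints ship. The following is a minimal, self-contained sketch of that sharing using candle's `Embedding` and `Linear` types; the tensor sizes and random weights are illustrative assumptions, not part of the patch.

```rust
// Minimal sketch of weight tying: the lm_head reuses the token-embedding
// matrix, mirroring `Linear::from_weights(wte.embeddings().clone(), None)`
// in the patch above. Sizes and weights are made up for illustration.
use candle_core::{Device, Result, Tensor};
use candle_nn::{Embedding, Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (vocab_size, hidden_size) = (32, 8);

    // One weight matrix of shape (vocab_size, hidden_size) shared by both layers.
    let wte_weights = Tensor::randn(0f32, 1f32, (vocab_size, hidden_size), &dev)?;
    let wte = Embedding::new(wte_weights.clone(), hidden_size);
    // Tied head: no separate lm_head weight and no bias.
    let lm_head = Linear::new(wte_weights, None);

    // Embed a few token ids, then project back to vocabulary logits.
    let token_ids = Tensor::new(&[1u32, 5, 7], &dev)?;
    let hidden = wte.forward(&token_ids)?; // (3, hidden_size)
    let logits = lm_head.forward(&hidden)?; // (3, vocab_size)
    assert_eq!(logits.dims(), &[3, vocab_size]);
    Ok(())
}
```

When `tie_word_embeddings` is absent from a model's `config.json`, the patch's `unwrap_or(false)` keeps the previous behaviour of loading a dedicated `lm_head` weight, so existing Llama checkpoints are unaffected.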