author     Laurent Mazare <laurent.mazare@gmail.com>  2024-09-26 21:00:18 +0200
committer  GitHub <noreply@github.com>                2024-09-26 21:00:18 +0200
commit     ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc (patch)
tree       d6514faca57dd0204170e04d6d6a94ca295fe278 /candle-transformers
parent     c3c392f45c14f60eb4fb8397cc5c1d3891c9656d (diff)
Add some llama-3.2 examples. (#2508)
* Add some llama-3.2 examples.
* Support tie-word-embeddings for llama.
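Weight tying here means the output projection (lm_head) reuses the token-embedding
matrix instead of loading a separate `lm_head.weight` tensor, which is how the
smaller Llama 3.2 checkpoints are distributed. Below is a minimal sketch of the
idea using plain candle_nn types; it is not part of this commit, which instead
builds the tied head via `Linear::from_weights` as shown in the diff further down.

// Minimal weight-tying sketch (illustrative only): the lm_head linear layer
// is built from the embedding table rather than from its own weight tensor.
use candle_core::{Device, Result, Tensor};
use candle_nn::{Embedding, Linear, Module};

fn tied_lm_head(embed: &Embedding) -> Linear {
    // The embedding table is (vocab_size, hidden_size); `Linear` applies
    // x @ w^T, so the same tensor maps hidden states back to vocab logits.
    Linear::new(embed.embeddings().clone(), None)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (vocab_size, hidden_size) = (8, 4);
    let table = Tensor::randn(0f32, 1., (vocab_size, hidden_size), &dev)?;
    let embed = Embedding::new(table, hidden_size);
    let lm_head = tied_lm_head(&embed);
    // A fake batch of hidden states: (batch, seq, hidden).
    let h = Tensor::randn(0f32, 1., (1, 3, hidden_size), &dev)?;
    let logits = lm_head.forward(&h)?;
    assert_eq!(logits.dims(), &[1, 3, vocab_size]);
    Ok(())
}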
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/llama.rs         | 11
-rw-r--r--  candle-transformers/src/models/llava/config.rs  |  3
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index e96bb855..a7bef099 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -44,6 +44,7 @@ pub struct LlamaConfig {
pub eos_token_id: Option<LlamaEosToks>,
pub rope_scaling: Option<Llama3RopeConfig>,
pub max_position_embeddings: usize,
+ pub tie_word_embeddings: Option<bool>,
}
impl LlamaConfig {
@@ -72,6 +73,7 @@ impl LlamaConfig {
eos_token_id: self.eos_token_id,
rope_scaling: self.rope_scaling,
max_position_embeddings: self.max_position_embeddings,
+ tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
}
}
}
@@ -91,6 +93,7 @@ pub struct Config {
pub eos_token_id: Option<LlamaEosToks>,
pub rope_scaling: Option<Llama3RopeConfig>,
pub max_position_embeddings: usize,
+ pub tie_word_embeddings: bool,
}
impl Config {
@@ -109,6 +112,7 @@ impl Config {
eos_token_id: None,
rope_scaling: None,
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
+ tie_word_embeddings: false,
}
}
@@ -127,6 +131,7 @@ impl Config {
eos_token_id: None,
rope_scaling: None,
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
+ tie_word_embeddings: false,
}
}
}
@@ -504,7 +509,11 @@ impl Llama {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
- let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+ let lm_head = if cfg.tie_word_embeddings {
+ Linear::from_weights(wte.embeddings().clone(), None)
+ } else {
+ linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+ };
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
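In the hunks above, the new field is an `Option<bool>` on `LlamaConfig` so that
older `config.json` files which omit the key keep the previous untied behaviour,
and `into_config` resolves `None` to `false`. The defaulting pattern, sketched
with a hypothetical mini struct (not the crate's own; serde deserializes a
missing `Option` field to `None`):

// Hypothetical mini config, only to illustrate the Option<bool> defaulting;
// requires serde (with the derive feature) and serde_json as dependencies.
use serde::Deserialize;

#[derive(Deserialize)]
struct MiniLlamaConfig {
    hidden_size: usize,
    tie_word_embeddings: Option<bool>,
}

fn main() -> serde_json::Result<()> {
    // Llama 3.2 style config: the flag is present and set.
    let tied: MiniLlamaConfig =
        serde_json::from_str(r#"{ "hidden_size": 2048, "tie_word_embeddings": true }"#)?;
    assert!(tied.tie_word_embeddings.unwrap_or(false));

    // Older configs simply omit the key; the model then keeps a separate lm_head.
    let untied: MiniLlamaConfig = serde_json::from_str(r#"{ "hidden_size": 4096 }"#)?;
    assert!(!untied.tie_word_embeddings.unwrap_or(false));
    Ok(())
}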
diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs
index 5dca6870..405eedb9 100644
--- a/candle-transformers/src/models/llava/config.rs
+++ b/candle-transformers/src/models/llava/config.rs
@@ -43,6 +43,7 @@ pub struct LLaVAConfig {
pub image_token_index: isize,
#[serde(default = "default_hf")]
pub hf: bool,
+ pub tie_word_embeddings: Option<bool>,
}
fn default_hf() -> bool {
@@ -77,6 +78,7 @@ impl LLaVAConfig {
use_flash_attn: false,
rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
max_position_embeddings: self.max_position_embeddings,
+ tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
}
}
}
@@ -264,6 +266,7 @@ impl HFLLaVAConfig {
use_cache: self.text_config.use_cache,
vocab_size: self.vocab_size,
image_token_index: self.image_token_index,
+ tie_word_embeddings: None,
}
}
}
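Putting it together, a rough sketch of how the flag flows from a downloaded
`config.json` into `Llama::load`. Paths, dtype, and the `into_config(use_flash_attn)`
call follow the usual candle llama example wiring and are assumptions here, not
taken from this commit; with `tie_word_embeddings: true` the checkpoint is not
expected to contain an `lm_head.weight` tensor.

// Rough loading sketch; assumes anyhow, serde_json, candle-core, candle-nn
// and candle-transformers as dependencies. Paths are placeholders.
use anyhow::Result;
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::llama::{Llama, LlamaConfig};

fn load_llama(config_json: &str, weight_files: &[std::path::PathBuf]) -> Result<Llama> {
    let device = Device::Cpu;
    // `tie_word_embeddings` is optional in the JSON; `into_config` maps a
    // missing value to `false`, matching the pre-existing behaviour.
    let cfg: LlamaConfig = serde_json::from_str(config_json)?;
    let cfg = cfg.into_config(false /* use_flash_attn */);
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(weight_files, DType::BF16, &device)? };
    Ok(Llama::load(vb, &cfg)?)
}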