diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-27 20:42:52 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-27 20:42:52 +0000 |
commit | 7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269 (patch) | |
tree | e53768f2111928259f615535be7ff65a941ab3ac | |
parent | e2eb6590ed57438f1c428112f0f444c1bc2acd6d (diff) | |
download | candle-7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269.tar.gz candle-7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269.tar.bz2 candle-7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269.zip |
Use the llama weight names for the Yi example. (#1381)
-rw-r--r-- | candle-examples/examples/yi/main.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/yi.rs | 8 |
2 files changed, 8 insertions, 4 deletions
diff --git a/candle-examples/examples/yi/main.rs b/candle-examples/examples/yi/main.rs index 7bb3be4a..a7184db9 100644 --- a/candle-examples/examples/yi/main.rs +++ b/candle-examples/examples/yi/main.rs @@ -74,9 +74,9 @@ impl TextGeneration { std::io::stdout().flush()?; let mut generated_tokens = 0usize; - let eos_token = match self.tokenizer.get_token("</s>") { + let eos_token = match self.tokenizer.get_token("<|endoftext|>") { Some(token) => token, - None => anyhow::bail!("cannot find the </s> token"), + None => anyhow::bail!("cannot find the <|endoftext|> token"), }; let start_gen = std::time::Instant::now(); for index in 0..sample_len { diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs index 0009ece6..14b6feeb 100644 --- a/candle-transformers/src/models/yi.rs +++ b/candle-transformers/src/models/yi.rs @@ -277,8 +277,12 @@ impl DecoderLayer { fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?; - let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?; - let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?; + let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let ln2 = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; Ok(Self { self_attn, mlp, |