summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-27 20:42:52 +0000
committerGitHub <noreply@github.com>2023-11-27 20:42:52 +0000
commit7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269 (patch)
treee53768f2111928259f615535be7ff65a941ab3ac
parente2eb6590ed57438f1c428112f0f444c1bc2acd6d (diff)
downloadcandle-7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269.tar.gz
candle-7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269.tar.bz2
candle-7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269.zip
Use the llama weight names for the Yi example. (#1381)
-rw-r--r--candle-examples/examples/yi/main.rs4
-rw-r--r--candle-transformers/src/models/yi.rs8
2 files changed, 8 insertions, 4 deletions
diff --git a/candle-examples/examples/yi/main.rs b/candle-examples/examples/yi/main.rs
index 7bb3be4a..a7184db9 100644
--- a/candle-examples/examples/yi/main.rs
+++ b/candle-examples/examples/yi/main.rs
@@ -74,9 +74,9 @@ impl TextGeneration {
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
- let eos_token = match self.tokenizer.get_token("</s>") {
+ let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
- None => anyhow::bail!("cannot find the </s> token"),
+ None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs
index 0009ece6..14b6feeb 100644
--- a/candle-transformers/src/models/yi.rs
+++ b/candle-transformers/src/models/yi.rs
@@ -277,8 +277,12 @@ impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
- let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
- let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
+ let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let ln2 = RmsNorm::new(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
Ok(Self {
self_attn,
mlp,