summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/yi/main.rs4
-rw-r--r--candle-transformers/src/models/yi.rs8
2 files changed, 8 insertions, 4 deletions
diff --git a/candle-examples/examples/yi/main.rs b/candle-examples/examples/yi/main.rs
index 7bb3be4a..a7184db9 100644
--- a/candle-examples/examples/yi/main.rs
+++ b/candle-examples/examples/yi/main.rs
@@ -74,9 +74,9 @@ impl TextGeneration {
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
- let eos_token = match self.tokenizer.get_token("</s>") {
+ let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
- None => anyhow::bail!("cannot find the </s> token"),
+ None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs
index 0009ece6..14b6feeb 100644
--- a/candle-transformers/src/models/yi.rs
+++ b/candle-transformers/src/models/yi.rs
@@ -277,8 +277,12 @@ impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
- let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
- let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
+ let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let ln2 = RmsNorm::new(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
Ok(Self {
self_attn,
mlp,