Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--  candle-examples/examples/llama2-c/main.rs  | 131
-rw-r--r--  candle-examples/examples/llama2-c/model.rs |  19
2 files changed, 85 insertions(+), 65 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 2e762f98..9b6d1316 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -13,6 +13,7 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
 use candle_nn::{Embedding, Linear, VarBuilder};
 use candle_transformers::generation::LogitsProcessor;
+use std::io::Write;
 
 use model::{Config, Llama};
 
@@ -38,21 +39,33 @@ struct TransformerWeights {
     freq_cis_imag: Tensor, // (seq_len, head_size/2)
 }
 
-impl Config {
-    fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
-        let mut buf = [0u8; 4];
-        r.read_exact(&mut buf)?;
-        Ok(i32::from_le_bytes(buf))
-    }
+fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
+    let mut buf = [0u8; 4];
+    r.read_exact(&mut buf)?;
+    Ok(i32::from_le_bytes(buf))
+}
 
+fn read_tensor<R: std::io::Read, S: Into<Shape>>(
+    r: &mut R,
+    shape: S,
+    dev: &Device,
+) -> Result<Tensor> {
+    let shape = shape.into();
+    let mut data_t = vec![0f32; shape.elem_count()];
+    r.read_f32_into::<LittleEndian>(&mut data_t)?;
+    let tensor = Tensor::from_vec(data_t, shape, dev)?;
+    Ok(tensor)
+}
+
+impl Config {
     fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
-        let dim = Self::read_i32(r)? as usize;
-        let hidden_dim = Self::read_i32(r)? as usize;
-        let n_layers = Self::read_i32(r)? as usize;
-        let n_heads = Self::read_i32(r)? as usize;
-        let n_kv_heads = Self::read_i32(r)? as usize;
-        let vocab_size = Self::read_i32(r)? as usize;
-        let seq_len = Self::read_i32(r)? as usize;
+        let dim = read_i32(r)? as usize;
+        let hidden_dim = read_i32(r)? as usize;
+        let n_layers = read_i32(r)? as usize;
+        let n_heads = read_i32(r)? as usize;
+        let n_kv_heads = read_i32(r)? as usize;
+        let vocab_size = read_i32(r)? as usize;
+        let seq_len = read_i32(r)? as usize;
         Ok(Self {
             dim,
             hidden_dim,
@@ -71,33 +84,21 @@ impl Config {
 }
 
 impl TransformerWeights {
-    fn read_tensor<R: std::io::Read, S: Into<Shape>>(
-        r: &mut R,
-        shape: S,
-        dev: &Device,
-    ) -> Result<Tensor> {
-        let shape = shape.into();
-        let mut data_t = vec![0f32; shape.elem_count()];
-        r.read_f32_into::<LittleEndian>(&mut data_t)?;
-        let tensor = Tensor::from_vec(data_t, shape, dev)?;
-        Ok(tensor)
-    }
-
     fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
-        let token_embedding_table = Self::read_tensor(r, (c.vocab_size, c.dim), dev)?;
-        let rms_att_weight = Self::read_tensor(r, (c.n_layers, c.dim), dev)?;
-        let wq = Self::read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wk = Self::read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wv = Self::read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wo = Self::read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let rms_ffn_weight = Self::read_tensor(r, (c.n_layers, c.dim), dev)?;
-        let w1 = Self::read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
-        let w2 = Self::read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
-        let w3 = Self::read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
-        let rms_final_weight = Self::read_tensor(r, c.dim, dev)?;
+        let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
+        let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+        let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+        let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+        let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+        let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+        let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+        let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+        let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
+        let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+        let rms_final_weight = read_tensor(r, c.dim, dev)?;
         let head_size = c.head_size();
-        let freq_cis_real = Self::read_tensor(r, (c.seq_len, head_size / 2), dev)?;
-        let freq_cis_imag = Self::read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+        let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+        let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
         Ok(Self {
             token_embedding_table,
             rms_att_weight,
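Taken together, the two hunks above pin down the binary checkpoint layout this example expects: a header of seven little-endian i32 fields (dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len) followed by the f32 weight tensors in the exact order read_tensor is called. A minimal sketch of writing such a header for a round-trip test of Config::from_reader; the write_header helper and the toy values are invented for illustration:

    use byteorder::{LittleEndian, WriteBytesExt};

    // Write the seven-field header that Config::from_reader parses, in order:
    // dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len.
    fn write_header<W: std::io::Write>(w: &mut W, fields: [i32; 7]) -> std::io::Result<()> {
        for f in fields {
            w.write_i32::<LittleEndian>(f)?;
        }
        Ok(())
    }

    fn main() -> std::io::Result<()> {
        let mut buf = Vec::new();
        // stories15M-like toy values, not taken from a real checkpoint.
        write_header(&mut buf, [288, 768, 6, 6, 6, 32000, 256])?;
        assert_eq!(buf.len(), 7 * 4); // seven little-endian i32s
        Ok(())
    }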
@@ -181,13 +182,36 @@ struct Args {
     /// Config file in binary format.
     #[arg(long)]
     config: String,
+
+    /// Tokenizer config file in binary format.
+    #[arg(long)]
+    tokenizer: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+}
+
+struct Tokenizer {
+    tokens: Vec<String>,
+}
+
+impl Tokenizer {
+    fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
+        let mut tokens = Vec::with_capacity(c.vocab_size);
+        for _token_index in 0..c.vocab_size {
+            let token_len = read_i32(r)?;
+            let mut token = vec![0u8; token_len as usize];
+            r.read_exact(&mut token)?;
+            tokens.push(String::from_utf8_lossy(&token).into_owned())
+        }
+        Ok(Self { tokens })
+    }
 }
 
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     let device = candle_examples::device(args.cpu)?;
-    let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
-    println!("{t}");
     let mut file = std::fs::File::open(&args.config)?;
     let config = Config::from_reader(&mut file)?;
     println!("config: {config:?}");
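Tokenizer::from_reader above implies a simple on-disk layout: vocab_size length-prefixed records, each a little-endian i32 byte count followed by that many raw bytes. A sketch of producing such a file to unit-test the reader; the helper name and the toy vocabulary are made up, and this only mirrors what this reader parses, not any other tokenizer format:

    use byteorder::{LittleEndian, WriteBytesExt};
    use std::io::Write;

    // Serialize tokens in the layout Tokenizer::from_reader parses:
    // per token, a little-endian i32 length, then the raw bytes.
    fn write_toy_tokenizer<W: Write>(w: &mut W, tokens: &[&str]) -> std::io::Result<()> {
        for tok in tokens {
            w.write_i32::<LittleEndian>(tok.len() as i32)?;
            w.write_all(tok.as_bytes())?;
        }
        Ok(())
    }

    fn main() -> std::io::Result<()> {
        let mut buf = Vec::new();
        write_toy_tokenizer(&mut buf, &["<s>", "hello", " world"])?;
        // Each record costs 4 length bytes plus the token bytes.
        assert_eq!(buf.len(), 3 * 4 + 3 + 5 + 6);
        Ok(())
    }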
@@ -196,13 +220,15 @@ fn main() -> anyhow::Result<()> {
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, &config)?;
 
+    let mut file = std::fs::File::open(&args.tokenizer)?;
+    let tokenizer = Tokenizer::from_reader(&mut file, &config)?;
+
     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(299792458, None);
-    let mut new_tokens: Vec<u32> = vec![];
-    let start_gen = std::time::Instant::now();
+    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
     let mut index_pos = 0;
     let mut tokens = vec![1u32];
+    let start_gen = std::time::Instant::now();
 
     for index in 0..config.seq_len - 10 {
         let start_gen = std::time::Instant::now();
         let context_size = if cache.use_kv_cache && index > 0 {
@@ -218,23 +244,14 @@ fn main() -> anyhow::Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
-        new_tokens.push(next_token);
-        println!("> {:?}", start_gen.elapsed());
-        println!(
-            "{} token: {} '{}'",
-            index + 1,
-            next_token,
-            0,
-            // tokenizer.decode(vec![next_token], true).map_err(E::msg)?
-        );
+        print!("{}", tokenizer.tokens[next_token as usize]);
+        std::io::stdout().flush()?;
     }
 
     let dt = start_gen.elapsed();
     println!(
-        "{} tokens generated ({} token/s)\n----\n{}\n----",
-        config.seq_len,
-        config.seq_len as f64 / dt.as_secs_f64(),
-        0,
-        // tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        "\n{} tokens generated ({:.2} token/s)\n",
+        tokens.len(),
+        tokens.len() as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
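With the loop above, the example now streams each sampled token through the Tokenizer table instead of printing per-token timing. A hypothetical invocation, assuming a llama2.c-style checkpoint and tokenizer file are available locally (both file names are placeholders):

    cargo run --example llama2-c --release -- \
        --config stories15M.bin --tokenizer tokenizer.bin --temperature 0.8

Omitting --temperature passes None to LogitsProcessor, which falls back to greedy (argmax) decoding rather than temperature sampling.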
+ "\n{} tokens generated ({:.2} token/s)\n", + tokens.len(), + tokens.len() as f64 / dt.as_secs_f64(), ); Ok(()) } diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs index 2fb4b444..13f939db 100644 --- a/candle-examples/examples/llama2-c/model.rs +++ b/candle-examples/examples/llama2-c/model.rs @@ -30,12 +30,14 @@ impl Cache { pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> { let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?; let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?; + let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?; + let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?; Ok(Self { masks: Arc::new(Mutex::new(HashMap::new())), use_kv_cache, kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])), - cos: freq_cis_real, - sin: freq_cis_imag, + cos, + sin, device: vb.device().clone(), }) } @@ -110,16 +112,17 @@ struct CausalSelfAttention { impl CausalSelfAttention { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let (b_sz, _, seq_len, n_embd) = x.dims4()?; + let (b_sz, seq_len, h, n_embd) = x.dims4()?; let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; - let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2))?; - let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2))?; - let x0 = x.narrow(D::Minus1, 0, n_embd / 2)?; - let x1 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?; + let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; + let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; - let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?; + let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?; Ok(rope) } |