Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 131 |
1 file changed, 74 insertions, 57 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 2e762f98..9b6d1316 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -13,6 +13,7 @@
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
 use candle_nn::{Embedding, Linear, VarBuilder};
 use candle_transformers::generation::LogitsProcessor;
+use std::io::Write;
 
 use model::{Config, Llama};
 
@@ -38,21 +39,33 @@ struct TransformerWeights {
     freq_cis_imag: Tensor, // (seq_len, head_size/2)
 }
 
-impl Config {
-    fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
-        let mut buf = [0u8; 4];
-        r.read_exact(&mut buf)?;
-        Ok(i32::from_le_bytes(buf))
-    }
-
+fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
+    let mut buf = [0u8; 4];
+    r.read_exact(&mut buf)?;
+    Ok(i32::from_le_bytes(buf))
+}
+
+fn read_tensor<R: std::io::Read, S: Into<Shape>>(
+    r: &mut R,
+    shape: S,
+    dev: &Device,
+) -> Result<Tensor> {
+    let shape = shape.into();
+    let mut data_t = vec![0f32; shape.elem_count()];
+    r.read_f32_into::<LittleEndian>(&mut data_t)?;
+    let tensor = Tensor::from_vec(data_t, shape, dev)?;
+    Ok(tensor)
+}
+
+impl Config {
     fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
-        let dim = Self::read_i32(r)? as usize;
-        let hidden_dim = Self::read_i32(r)? as usize;
-        let n_layers = Self::read_i32(r)? as usize;
-        let n_heads = Self::read_i32(r)? as usize;
-        let n_kv_heads = Self::read_i32(r)? as usize;
-        let vocab_size = Self::read_i32(r)? as usize;
-        let seq_len = Self::read_i32(r)? as usize;
+        let dim = read_i32(r)? as usize;
+        let hidden_dim = read_i32(r)? as usize;
+        let n_layers = read_i32(r)? as usize;
+        let n_heads = read_i32(r)? as usize;
+        let n_kv_heads = read_i32(r)? as usize;
+        let vocab_size = read_i32(r)? as usize;
+        let seq_len = read_i32(r)? as usize;
         Ok(Self {
             dim,
             hidden_dim,
@@ -71,33 +84,21 @@ impl Config {
 }
 
 impl TransformerWeights {
-    fn read_tensor<R: std::io::Read, S: Into<Shape>>(
-        r: &mut R,
-        shape: S,
-        dev: &Device,
-    ) -> Result<Tensor> {
-        let shape = shape.into();
-        let mut data_t = vec![0f32; shape.elem_count()];
-        r.read_f32_into::<LittleEndian>(&mut data_t)?;
-        let tensor = Tensor::from_vec(data_t, shape, dev)?;
-        Ok(tensor)
-    }
-
     fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
-        let token_embedding_table = Self::read_tensor(r, (c.vocab_size, c.dim), dev)?;
-        let rms_att_weight = Self::read_tensor(r, (c.n_layers, c.dim), dev)?;
-        let wq = Self::read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wk = Self::read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wv = Self::read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wo = Self::read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let rms_ffn_weight = Self::read_tensor(r, (c.n_layers, c.dim), dev)?;
-        let w1 = Self::read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
-        let w2 = Self::read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
-        let w3 = Self::read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
-        let rms_final_weight = Self::read_tensor(r, c.dim, dev)?;
+        let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
+        let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+        let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+        let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+        let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+        let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+        let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+        let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+        let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
+        let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+        let rms_final_weight = read_tensor(r, c.dim, dev)?;
         let head_size = c.head_size();
-        let freq_cis_real = Self::read_tensor(r, (c.seq_len, head_size / 2), dev)?;
-        let freq_cis_imag = Self::read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+        let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+        let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
         Ok(Self {
             token_embedding_table,
             rms_att_weight,
@@ -181,13 +182,36 @@ struct Args {
     /// Config file in binary format.
     #[arg(long)]
     config: String,
+
+    /// Tokenizer config file in binary format.
+    #[arg(long)]
+    tokenizer: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+}
+
+struct Tokenizer {
+    tokens: Vec<String>,
+}
+
+impl Tokenizer {
+    fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
+        let mut tokens = Vec::with_capacity(c.vocab_size);
+        for _token_index in 0..c.vocab_size {
+            let token_len = read_i32(r)?;
+            let mut token = vec![0u8; token_len as usize];
+            r.read_exact(&mut token)?;
+            tokens.push(String::from_utf8_lossy(&token).into_owned())
+        }
+        Ok(Self { tokens })
+    }
 }
 
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     let device = candle_examples::device(args.cpu)?;
-    let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
-    println!("{t}");
     let mut file = std::fs::File::open(&args.config)?;
     let config = Config::from_reader(&mut file)?;
     println!("config: {config:?}");
@@ -196,13 +220,15 @@ fn main() -> anyhow::Result<()> {
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, &config)?;
 
+    let mut file = std::fs::File::open(&args.tokenizer)?;
+    let tokenizer = Tokenizer::from_reader(&mut file, &config)?;
+
    println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(299792458, None);
-    let mut new_tokens: Vec<u32> = vec![];
-    let start_gen = std::time::Instant::now();
+    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
     let mut index_pos = 0;
     let mut tokens = vec![1u32];
 
+    let start_gen = std::time::Instant::now();
     for index in 0..config.seq_len - 10 {
         let start_gen = std::time::Instant::now();
         let context_size = if cache.use_kv_cache && index > 0 {
@@ -218,23 +244,14 @@ fn main() -> anyhow::Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
-        new_tokens.push(next_token);
-        println!("> {:?}", start_gen.elapsed());
-        println!(
-            "{} token: {} '{}'",
-            index + 1,
-            next_token,
-            0,
-            // tokenizer.decode(vec![next_token], true).map_err(E::msg)?
-        );
+        print!("{}", tokenizer.tokens[next_token as usize]);
+        std::io::stdout().flush()?;
     }
     let dt = start_gen.elapsed();
     println!(
-        "{} tokens generated ({} token/s)\n----\n{}\n----",
-        config.seq_len,
-        config.seq_len as f64 / dt.as_secs_f64(),
-        0,
-        // tokenizer.decode(new_tokens, true).map_err(E::msg)?
+        "\n{} tokens generated ({:.2} token/s)\n",
+        tokens.len(),
+        tokens.len() as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
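For reference, the tokenizer file consumed by Tokenizer::from_reader above is a length-prefixed string table: for each of the vocab_size entries, a little-endian i32 byte length followed by that many UTF-8 bytes. The std-only sketch below illustrates parsing that layout in isolation; the read_tokens helper and the in-memory test records are illustrative only and are not part of the example.

use std::io::Read;

// Read a little-endian i32, mirroring the example's read_i32 helper.
fn read_i32<R: Read>(r: &mut R) -> std::io::Result<i32> {
    let mut buf = [0u8; 4];
    r.read_exact(&mut buf)?;
    Ok(i32::from_le_bytes(buf))
}

// Parse vocab_size [len: i32 LE][bytes] records into strings.
fn read_tokens<R: Read>(r: &mut R, vocab_size: usize) -> std::io::Result<Vec<String>> {
    let mut tokens = Vec::with_capacity(vocab_size);
    for _ in 0..vocab_size {
        let len = read_i32(r)? as usize;
        let mut bytes = vec![0u8; len];
        r.read_exact(&mut bytes)?;
        tokens.push(String::from_utf8_lossy(&bytes).into_owned());
    }
    Ok(tokens)
}

fn main() -> std::io::Result<()> {
    // Encode three fake tokens as length-prefixed records (test data only).
    let mut data = Vec::new();
    for t in ["<s>", "he", "llo"] {
        data.extend_from_slice(&(t.len() as i32).to_le_bytes());
        data.extend_from_slice(t.as_bytes());
    }
    // A &[u8] slice implements Read, so no file is needed for the demo.
    let tokens = read_tokens(&mut &data[..], 3)?;
    assert_eq!(tokens, ["<s>", "he", "llo"]);
    println!("{tokens:?}");
    Ok(())
}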
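The new --temperature flag is passed to LogitsProcessor::new(299792458, args.temperature). Conceptually, temperature sampling divides the logits by the temperature before the softmax and then draws from the resulting distribution: low temperatures sharpen it toward the argmax token, high temperatures flatten it. The standalone sketch below illustrates that idea with a fixed uniform draw; sample_with_temperature is a hypothetical helper, not the actual candle-transformers implementation, which may differ in detail.

// Scale logits by 1/temperature, softmax, then inverse-CDF sample
// using the provided uniform draw in [0, 1).
fn sample_with_temperature(logits: &[f32], temperature: f64, uniform: f64) -> usize {
    let scaled: Vec<f64> = logits.iter().map(|&l| l as f64 / temperature).collect();
    // Subtract the max before exponentiating for numerical stability.
    let max = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scaled.iter().map(|&l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    let mut cdf = 0.0;
    for (i, e) in exps.iter().enumerate() {
        cdf += e / sum;
        if uniform < cdf {
            return i;
        }
    }
    logits.len() - 1
}

fn main() {
    let logits = [2.0f32, 1.0, 0.5];
    // With the same uniform draw, raising the temperature makes it more
    // likely that a non-argmax token is picked.
    for temperature in [0.5, 1.0, 2.0] {
        let idx = sample_with_temperature(&logits, temperature, 0.6);
        println!("temperature {temperature}: sampled token {idx}");
    }
}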