Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--  candle-examples/examples/llama2-c/main.rs      |  28
-rw-r--r--  candle-examples/examples/llama2-c/training.rs  | 124
-rw-r--r--  candle-examples/examples/llama2-c/weights.rs   |  25
3 files changed, 43 insertions, 134 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 8b64fdd2..418218b6 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -1,5 +1,8 @@
 // https://github.com/karpathy/llama2.c
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
@@ -27,7 +30,7 @@ struct InferenceCmd {
     #[arg(long, default_value = "")]
     prompt: String,

-    /// Config file in binary format.
+    /// Config file in binary or safetensors format.
     #[arg(long)]
     config: Option<String>,
@@ -200,7 +203,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
                 Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
             }
         });
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
         let logits = model.forward(&inp, 0)?;
@@ -225,11 +228,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let device = candle_examples::device(common_args.cpu)?;

-    let mut file = std::fs::File::open(config_path)?;
-    let config = Config::from_reader(&mut file)?;
-    println!("{config:?}");
-    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
-    let vb = weights.var_builder(&config, &device)?;
+    let is_safetensors = config_path
+        .extension()
+        .map_or(false, |v| v == "safetensors");
+    let (vb, config) = if is_safetensors {
+        let config = Config::tiny();
+        let tensors = candle::safetensors::load(config_path, &device)?;
+        let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
+        (vb, config)
+    } else {
+        let mut file = std::fs::File::open(config_path)?;
+        let config = Config::from_reader(&mut file)?;
+        println!("{config:?}");
+        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+        let vb = weights.var_builder(&config, &device)?;
+        (vb, config)
+    };
     let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
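Note on the main.rs hunk: it dispatches on the file extension because safetensors files carry named tensors that can back a VarBuilder directly, while the llama2.c binary format interleaves a config header with raw weights. A minimal self-contained sketch of that dispatch follows; the load_vb helper name is ours, not the commit's, but the candle calls (candle::safetensors::load, VarBuilder::from_tensors) are the ones used in the diff.

use candle::{DType, Device};
use candle_nn::VarBuilder;
use std::path::Path;

/// Hypothetical helper mirroring the dispatch in run_inference.
fn load_vb(config_path: &Path, device: &Device) -> candle::Result<VarBuilder<'static>> {
    if config_path.extension().map_or(false, |v| v == "safetensors") {
        // safetensors files store named tensors, so they can feed a VarBuilder directly.
        let tensors = candle::safetensors::load(config_path, device)?;
        Ok(VarBuilder::from_tensors(tensors, DType::F32, device))
    } else {
        // The llama2.c binary format needs its header parsed first; in main.rs this is
        // Config::from_reader followed by TransformerWeights::from_reader.
        candle::bail!("not a safetensors file: {config_path:?}")
    }
}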
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index e55c686c..3e93c786 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -1,118 +1,6 @@
-#![allow(dead_code)]
-#![allow(unused)]
 use crate::model::{Cache, Config, Llama};
-use candle::{DType, Device, Result, Tensor};
-
-pub struct Dataset {
-    valid_tokens: Vec<memmap2::Mmap>,
-    train_tokens: Vec<memmap2::Mmap>,
-}
-
-fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
-    let file = std::fs::File::open(p)?;
-    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
-    Ok(mmap)
-}
-
-impl Dataset {
-    pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
-        let dir = dir.as_ref();
-        let mut bin_files = vec![];
-        for file in std::fs::read_dir(dir)?.flatten() {
-            let file = file.path();
-            if let Some(extension) = file.extension() {
-                if extension == "bin" {
-                    bin_files.push(file)
-                }
-            }
-        }
-        if bin_files.len() < 2 {
-            candle::bail!("found less than two bin files in {:?}", dir)
-        }
-        bin_files.sort();
-        let valid_tokens = mmap_file(&bin_files[0])?;
-        let train_tokens = bin_files[1..]
-            .iter()
-            .map(mmap_file)
-            .collect::<Result<Vec<_>>>()?;
-        Ok(Self {
-            valid_tokens: vec![valid_tokens],
-            train_tokens,
-        })
-    }
-}
-
-struct DatasetRandomIter<'a> {
-    all_tokens: &'a [memmap2::Mmap],
-    tokens: Vec<&'a memmap2::Mmap>,
-    current_tokens: &'a memmap2::Mmap,
-    indexes_in_bytes: Vec<usize>,
-    seq_len: usize,
-    device: Device,
-}
-
-impl<'a> DatasetRandomIter<'a> {
-    pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
-        use rand::seq::SliceRandom;
-        use rand::thread_rng;
-
-        let all_tokens = if valid {
-            &ds.valid_tokens
-        } else {
-            &ds.train_tokens
-        };
-        let mut tokens = all_tokens.iter().collect::<Vec<_>>();
-        tokens.shuffle(&mut thread_rng());
-        let current_tokens = tokens.pop().unwrap();
-        let seq_len_in_bytes = seq_len * 2;
-        let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
-            .step_by(seq_len_in_bytes)
-            .collect::<Vec<_>>();
-        indexes_in_bytes.shuffle(&mut thread_rng());
-        Self {
-            all_tokens,
-            tokens,
-            current_tokens,
-            indexes_in_bytes,
-            seq_len,
-            device,
-        }
-    }
-}
-
-impl<'a> Iterator for DatasetRandomIter<'a> {
-    type Item = Result<(Tensor, Tensor)>;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        use byteorder::{LittleEndian, ReadBytesExt};
-        use rand::seq::SliceRandom;
-        use rand::thread_rng;
-
-        let seq_len = self.seq_len;
-        if self.indexes_in_bytes.is_empty() {
-            if self.tokens.is_empty() {
-                self.tokens = self.all_tokens.iter().collect();
-                self.tokens.shuffle(&mut thread_rng());
-            }
-            self.current_tokens = self.tokens.pop().unwrap();
-            let seq_len_in_bytes = self.seq_len * 2;
-            self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
-                .step_by(seq_len_in_bytes)
-                .collect::<Vec<_>>();
-            self.indexes_in_bytes.shuffle(&mut thread_rng());
-        }
-        let start_idx = self.indexes_in_bytes.pop().unwrap();
-        let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
-        let mut tokens = vec![0u16; bytes.len() / 2];
-        if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
-            return Some(Err(err.into()));
-        }
-        let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
-        let inputs = Tensor::new(&tokens[..seq_len], &self.device);
-        let targets = Tensor::new(&tokens[1..], &self.device);
-        Some(candle::error::zip(inputs, targets))
-    }
-}
+use candle::{DType, Device, Result};
+use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};

 fn valid_loss(
     dataset: &Dataset,
@@ -121,7 +9,7 @@ fn valid_loss(
     device: &Device,
 ) -> Result<f64> {
     let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     let mut sum_ce = 0f64;
     let mut cnt = 0usize;
     for inp_tgt in batch_iter.take(50) {
@@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let dataset = Dataset::new(&args.pretokenized_dir)?;
     println!(
         "loaded dataset, train: {} files, valid: {} files",
-        dataset.train_tokens.len(),
-        dataset.valid_tokens.len()
+        dataset.train_tokens(),
+        dataset.valid_tokens()
     );
     let varmap = candle_nn::VarMap::new();
     let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
     let config = Config::tiny();
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
-    let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     let cache = Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
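The recurring change across these hunks moves Batcher from candle_nn::dataset to the new candle-datasets crate. A minimal sketch of how Batcher::new_r2 consumes an iterator of Result<(Tensor, Tensor)> pairs, like the one DatasetRandomIter yields; the toy pairs iterator is ours, and the shapes are purely illustrative.

use candle::{Device, Result, Tensor};

/// Illustrative only: feed Batcher::new_r2 with (input, target) pairs the way
/// DatasetRandomIter does, and get stacked batch tensors back.
fn demo_batcher(device: &Device) -> Result<()> {
    let pairs = (0..8u32).map(|i| -> Result<(Tensor, Tensor)> {
        let inp = Tensor::new(&[i, i + 1, i + 2], device)?;
        let tgt = Tensor::new(&[i + 1, i + 2, i + 3], device)?;
        Ok((inp, tgt))
    });
    // new_r2 stacks the left and right elements separately, so each batch is a
    // pair of (batch_size, seq_len) tensors, here (4, 3).
    let batch_iter = candle_datasets::Batcher::new_r2(pairs).batch_size(4);
    for batch in batch_iter {
        let (inp, tgt) = batch?;
        println!("{:?} {:?}", inp.shape(), tgt.shape());
    }
    Ok(())
}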
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs
index ae1fd6d9..b78418ce 100644
--- a/candle-examples/examples/llama2-c/weights.rs
+++ b/candle-examples/examples/llama2-c/weights.rs
@@ -104,7 +104,14 @@ impl TransformerWeights {
         })
     }

-    pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+    pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
+        // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
+        // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
+        // second matrix back. We detect this case here and as a temporary hack make the weight
+        // matrix column major rather than row major. This ends up speeding up text generation from
+        // 120 token/s to 220 token/s on a Ryzen 2600X.
+        let tr = device.is_cpu() && !candle::utils::has_mkl();
+        let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
         let mut ws = std::collections::HashMap::new();
         let mut insert = |name: &str, t: Tensor| {
             ws.insert(name.to_string(), t);
@@ -115,36 +122,36 @@ impl TransformerWeights {
             "model.embed_tokens.weight",
             self.token_embedding_table.clone(),
         );
-        insert("lm_head.weight", self.token_embedding_table.clone());
+        insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
         insert("model.norm.weight", self.rms_final_weight.clone());
         for layer in 0..cfg.n_layers {
             ws.insert(
                 format!("model.layers.{layer}.self_attn.q_proj.weight"),
-                self.wq.i(layer)?,
+                tr(self.wq.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.k_proj.weight"),
-                self.wk.i(layer)?,
+                tr(self.wk.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.v_proj.weight"),
-                self.wv.i(layer)?,
+                tr(self.wv.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.o_proj.weight"),
-                self.wo.i(layer)?,
+                tr(self.wo.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.gate_proj.weight"),
-                self.w1.i(layer)?,
+                tr(self.w1.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.down_proj.weight"),
-                self.w2.i(layer)?,
+                tr(self.w2.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.up_proj.weight"),
-                self.w3.i(layer)?,
+                tr(self.w3.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.input_layernorm.weight"),
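The tr closure in the weights.rs hunk relies on x.t()?.contiguous()?.t() preserving a tensor's logical shape while switching its storage to column major, which is the layout the CPU (1, k) x (k, n)-transposed matmul path prefers per the TODO above. A small sketch demonstrating that property; the to_col_major helper is ours, not part of the commit.

use candle::{Device, Result, Tensor};

/// Hypothetical helper isolating the trick used by the `tr` closure.
fn to_col_major(w: Tensor) -> Result<Tensor> {
    // Transpose, force a contiguous copy of the transposed data, then transpose
    // back: same shape and values, but the storage is now column major.
    w.t()?.contiguous()?.t()
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let w = Tensor::rand(0f32, 1f32, (4, 3), &dev)?;
    let w_cm = to_col_major(w.clone())?;
    assert_eq!(w.dims(), w_cm.dims()); // logical shape is unchanged
    assert!(!w_cm.is_contiguous()); // but the row-major layout is gone
    Ok(())
}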