summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--candle-examples/examples/llama2-c/main.rs28
-rw-r--r--candle-examples/examples/llama2-c/training.rs124
-rw-r--r--candle-examples/examples/llama2-c/weights.rs25
3 files changed, 43 insertions, 134 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 8b64fdd2..418218b6 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -1,5 +1,8 @@
// https://github.com/karpathy/llama2.c
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -27,7 +30,7 @@ struct InferenceCmd {
#[arg(long, default_value = "")]
prompt: String,
- /// Config file in binary format.
+ /// Config file in binary or safetensors format.
#[arg(long)]
config: Option<String>,
@@ -200,7 +203,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
}
});
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
@@ -225,11 +228,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
- let mut file = std::fs::File::open(config_path)?;
- let config = Config::from_reader(&mut file)?;
- println!("{config:?}");
- let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
- let vb = weights.var_builder(&config, &device)?;
+ let is_safetensors = config_path
+ .extension()
+ .map_or(false, |v| v == "safetensors");
+ let (vb, config) = if is_safetensors {
+ let config = Config::tiny();
+ let tensors = candle::safetensors::load(config_path, &device)?;
+ let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
+ (vb, config)
+ } else {
+ let mut file = std::fs::File::open(config_path)?;
+ let config = Config::from_reader(&mut file)?;
+ println!("{config:?}");
+ let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+ let vb = weights.var_builder(&config, &device)?;
+ (vb, config)
+ };
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index e55c686c..3e93c786 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -1,118 +1,6 @@
-#![allow(dead_code)]
-#![allow(unused)]
use crate::model::{Cache, Config, Llama};
-use candle::{DType, Device, Result, Tensor};
-
-pub struct Dataset {
- valid_tokens: Vec<memmap2::Mmap>,
- train_tokens: Vec<memmap2::Mmap>,
-}
-
-fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
- let file = std::fs::File::open(p)?;
- let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
- Ok(mmap)
-}
-
-impl Dataset {
- pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
- let dir = dir.as_ref();
- let mut bin_files = vec![];
- for file in std::fs::read_dir(dir)?.flatten() {
- let file = file.path();
- if let Some(extension) = file.extension() {
- if extension == "bin" {
- bin_files.push(file)
- }
- }
- }
- if bin_files.len() < 2 {
- candle::bail!("found less than two bin files in {:?}", dir)
- }
- bin_files.sort();
- let valid_tokens = mmap_file(&bin_files[0])?;
- let train_tokens = bin_files[1..]
- .iter()
- .map(mmap_file)
- .collect::<Result<Vec<_>>>()?;
- Ok(Self {
- valid_tokens: vec![valid_tokens],
- train_tokens,
- })
- }
-}
-
-struct DatasetRandomIter<'a> {
- all_tokens: &'a [memmap2::Mmap],
- tokens: Vec<&'a memmap2::Mmap>,
- current_tokens: &'a memmap2::Mmap,
- indexes_in_bytes: Vec<usize>,
- seq_len: usize,
- device: Device,
-}
-
-impl<'a> DatasetRandomIter<'a> {
- pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
- use rand::seq::SliceRandom;
- use rand::thread_rng;
-
- let all_tokens = if valid {
- &ds.valid_tokens
- } else {
- &ds.train_tokens
- };
- let mut tokens = all_tokens.iter().collect::<Vec<_>>();
- tokens.shuffle(&mut thread_rng());
- let current_tokens = tokens.pop().unwrap();
- let seq_len_in_bytes = seq_len * 2;
- let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
- .step_by(seq_len_in_bytes)
- .collect::<Vec<_>>();
- indexes_in_bytes.shuffle(&mut thread_rng());
- Self {
- all_tokens,
- tokens,
- current_tokens,
- indexes_in_bytes,
- seq_len,
- device,
- }
- }
-}
-
-impl<'a> Iterator for DatasetRandomIter<'a> {
- type Item = Result<(Tensor, Tensor)>;
-
- fn next(&mut self) -> Option<Self::Item> {
- use byteorder::{LittleEndian, ReadBytesExt};
- use rand::seq::SliceRandom;
- use rand::thread_rng;
-
- let seq_len = self.seq_len;
- if self.indexes_in_bytes.is_empty() {
- if self.tokens.is_empty() {
- self.tokens = self.all_tokens.iter().collect();
- self.tokens.shuffle(&mut thread_rng());
- }
- self.current_tokens = self.tokens.pop().unwrap();
- let seq_len_in_bytes = self.seq_len * 2;
- self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
- .step_by(seq_len_in_bytes)
- .collect::<Vec<_>>();
- self.indexes_in_bytes.shuffle(&mut thread_rng());
- }
- let start_idx = self.indexes_in_bytes.pop().unwrap();
- let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
- let mut tokens = vec![0u16; bytes.len() / 2];
- if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
- return Some(Err(err.into()));
- }
- let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
- let inputs = Tensor::new(&tokens[..seq_len], &self.device);
- let targets = Tensor::new(&tokens[1..], &self.device);
- Some(candle::error::zip(inputs, targets))
- }
-}
+use candle::{DType, Device, Result};
+use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
fn valid_loss(
dataset: &Dataset,
@@ -121,7 +9,7 @@ fn valid_loss(
device: &Device,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut sum_ce = 0f64;
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
@@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let dataset = Dataset::new(&args.pretokenized_dir)?;
println!(
"loaded dataset, train: {} files, valid: {} files",
- dataset.train_tokens.len(),
- dataset.valid_tokens.len()
+ dataset.train_tokens(),
+ dataset.valid_tokens()
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs
index ae1fd6d9..b78418ce 100644
--- a/candle-examples/examples/llama2-c/weights.rs
+++ b/candle-examples/examples/llama2-c/weights.rs
@@ -104,7 +104,14 @@ impl TransformerWeights {
})
}
- pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+ pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
+ // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
+ // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
+ // second matrix back. We detect this case here and as a temporary hack make the weight
+ // matrix column major rather than row major. This ends up speeding up text generation from
+ // 120 token/s to 220 token/s on a Ryzen 2600X.
+ let tr = device.is_cpu() && !candle::utils::has_mkl();
+ let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
let mut ws = std::collections::HashMap::new();
let mut insert = |name: &str, t: Tensor| {
ws.insert(name.to_string(), t);
@@ -115,36 +122,36 @@ impl TransformerWeights {
"model.embed_tokens.weight",
self.token_embedding_table.clone(),
);
- insert("lm_head.weight", self.token_embedding_table.clone());
+ insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
insert("model.norm.weight", self.rms_final_weight.clone());
for layer in 0..cfg.n_layers {
ws.insert(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
- self.wq.i(layer)?,
+ tr(self.wq.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
- self.wk.i(layer)?,
+ tr(self.wk.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
- self.wv.i(layer)?,
+ tr(self.wv.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
- self.wo.i(layer)?,
+ tr(self.wo.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
- self.w1.i(layer)?,
+ tr(self.w1.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.down_proj.weight"),
- self.w2.i(layer)?,
+ tr(self.w2.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.up_proj.weight"),
- self.w3.i(layer)?,
+ tr(self.w3.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.input_layernorm.weight"),