Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
 candle-examples/examples/llama2-c/main.rs | 60 +++++++++++++++++++++++++++---
 1 file changed, 54 insertions(+), 6 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index e752a494..77dbc677 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -7,6 +7,7 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 mod model;
+mod qmodel;
 mod training;
 mod weights;
 use clap::{Parser, Subcommand};
@@ -19,6 +20,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 
 use model::{Config, Llama};
+use qmodel::QLlama;
 use weights::TransformerWeights;
 
 #[derive(Parser, Debug, Clone)]
@@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> {
     Ok(())
 }
 
+enum Model {
+    Llama(Llama),
+    QLlama(QLlama),
+}
+
+impl Model {
+    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
+        match self {
+            Self::Llama(l) => Ok(l.forward(xs, pos)?),
+            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
+        }
+    }
+}
+
 fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     use std::io::BufRead;
 
@@ -241,24 +257,56 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
 
     let device = candle_examples::device(common_args.cpu)?;
 
+    let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
     let is_safetensors = config_path
         .extension()
         .map_or(false, |v| v == "safetensors");
-    let (vb, config) = if is_safetensors {
+    let (model, config) = if is_gguf {
+        let config = Config::tiny();
+        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let freq_cis_real = vb
+            .get(
+                (config.seq_len, config.head_size() / 2),
+                "rot.freq_cis_real",
+            )?
+            .dequantize(&candle::Device::Cpu)?;
+        let freq_cis_imag = vb
+            .get(
+                (config.seq_len, config.head_size() / 2),
+                "rot.freq_cis_imag",
+            )?
+            .dequantize(&candle::Device::Cpu)?;
+
+        let fake_vb = candle_nn::VarBuilder::from_tensors(
+            [
+                ("freq_cis_real".to_string(), freq_cis_real),
+                ("freq_cis_imag".to_string(), freq_cis_imag),
+            ]
+            .into_iter()
+            .collect(),
+            candle::DType::F32,
+            &candle::Device::Cpu,
+        );
+        let cache = model::Cache::new(true, &config, fake_vb)?;
+        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
+        (model, config)
+    } else if is_safetensors {
         let config = Config::tiny();
         let tensors = candle::safetensors::load(config_path, &device)?;
         let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
-        (vb, config)
+        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
+        (model, config)
     } else {
        let mut file = std::fs::File::open(config_path)?;
        let config = Config::from_reader(&mut file)?;
        println!("{config:?}");
        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
        let vb = weights.var_builder(&config, &device)?;
-        (vb, config)
+        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
+        (model, config)
     };
-    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
 
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@@ -273,7 +321,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
 
     let start_gen = std::time::Instant::now();
     for index in 0.. {
-        if tokens.len() >= model.config.seq_len {
+        if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
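
Note on the gguf branch: the existing model::Cache::new expects a float VarBuilder, so the quantized path dequantizes the two RoPE tables ("rot.freq_cis_real", "rot.freq_cis_imag") to f32 and repackages them in an in-memory candle_nn::VarBuilder (the "fake_vb" above). Below is a minimal sketch of that repackaging trick, assuming the candle (candle-core) and candle_nn crates as dependencies; the helper name fake_rope_vb and the dummy shapes are ours, not from the diff (though (256, 24) matches Config::tiny: seq_len 256, head_size 48 / 2).

use std::collections::HashMap;

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

// Wrap plain tensors in an in-memory VarBuilder so code that expects to
// read named tensors (here: a cache looking up "freq_cis_real" and
// "freq_cis_imag") can consume tensors that came from elsewhere, e.g.
// dequantized gguf data. Hypothetical helper, mirroring the diff's fake_vb.
fn fake_rope_vb(freq_cis_real: Tensor, freq_cis_imag: Tensor) -> VarBuilder<'static> {
    let tensors: HashMap<String, Tensor> = [
        ("freq_cis_real".to_string(), freq_cis_real),
        ("freq_cis_imag".to_string(), freq_cis_imag),
    ]
    .into_iter()
    .collect();
    VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu)
}

fn main() -> Result<()> {
    // Dummy stand-ins for the dequantized RoPE tables.
    let real = Tensor::zeros((256, 24), DType::F32, &Device::Cpu)?;
    let imag = Tensor::zeros((256, 24), DType::F32, &Device::Cpu)?;
    let vb = fake_rope_vb(real, imag);
    // The lookup now behaves as if the tensors had been loaded from a file.
    let _t = vb.get((256, 24), "freq_cis_real")?;
    Ok(())
}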
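
And a dependency-free sketch of the enum-dispatch pattern the diff introduces: the sampling loop calls a single forward method and stays agnostic about whether the weights are float or quantized. All types below are stand-ins for the real candle ones. This also explains the last hunk: the Model wrapper does not expose a config field, so the loop's bound changes from model.config.seq_len to the caller's own config.seq_len.

struct Tensor; // stand-in for candle::Tensor
struct Llama;  // stand-in for the f32 model
struct QLlama; // stand-in for the gguf/quantized model

impl Llama {
    fn forward(&self, _xs: &Tensor, _pos: usize) -> Result<Tensor, String> {
        Ok(Tensor) // the real model runs the transformer here
    }
}

impl QLlama {
    fn forward(&self, _xs: &Tensor, _pos: usize) -> Result<Tensor, String> {
        Ok(Tensor) // same signature, quantized kernels inside
    }
}

enum Model {
    Llama(Llama),
    QLlama(QLlama),
}

impl Model {
    // Single entry point; callers never match on the variant themselves.
    fn forward(&self, xs: &Tensor, pos: usize) -> Result<Tensor, String> {
        match self {
            Self::Llama(m) => m.forward(xs, pos),
            Self::QLlama(m) => m.forward(xs, pos),
        }
    }
}

fn main() {
    let model = Model::QLlama(QLlama);
    let _logits = model.forward(&Tensor, 0).expect("forward failed");
}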