Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r--  candle-examples/examples/llama2-c/main.rs | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 54 insertions(+), 6 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index e752a494..77dbc677 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -7,6 +7,7 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 mod model;
+mod qmodel;
 mod training;
 mod weights;
 use clap::{Parser, Subcommand};
@@ -19,6 +20,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 use model::{Config, Llama};
+use qmodel::QLlama;
 use weights::TransformerWeights;
 #[derive(Parser, Debug, Clone)]
@@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> {
     Ok(())
 }
+enum Model {
+    Llama(Llama),
+    QLlama(QLlama),
+}
+
+impl Model {
+    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
+        match self {
+            Self::Llama(l) => Ok(l.forward(xs, pos)?),
+            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
+        }
+    }
+}
+
 fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     use std::io::BufRead;
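The Model enum above gives the inference loop a single forward entry point whichever backend was loaded, keeping dispatch static instead of threading a generic parameter or a trait object through run_inference. A minimal sketch of one decoding step written against it; step is a made-up helper for illustration, not part of this file:

use candle::{Device, Tensor};

// One decoding step against the enum. `pos` is the index of the first
// token of `ctxt` within the full sequence, so the KV cache knows where
// these activations belong.
fn step(model: &Model, ctxt: &[u32], pos: usize, device: &Device) -> anyhow::Result<Tensor> {
    let input = Tensor::new(ctxt, device)?.unsqueeze(0)?; // add batch dim: (1, n)
    let logits = model.forward(&input, pos)?; // static dispatch via match
    Ok(logits.squeeze(0)?) // drop the batch dim again
}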
@@ -241,24 +257,56 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let device = candle_examples::device(common_args.cpu)?;
+    let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
     let is_safetensors = config_path
         .extension()
         .map_or(false, |v| v == "safetensors");
-    let (vb, config) = if is_safetensors {
+    let (model, config) = if is_gguf {
+        let config = Config::tiny();
+        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let freq_cis_real = vb
+            .get(
+                (config.seq_len, config.head_size() / 2),
+                "rot.freq_cis_real",
+            )?
+            .dequantize(&candle::Device::Cpu)?;
+        let freq_cis_imag = vb
+            .get(
+                (config.seq_len, config.head_size() / 2),
+                "rot.freq_cis_imag",
+            )?
+            .dequantize(&candle::Device::Cpu)?;
+
+        let fake_vb = candle_nn::VarBuilder::from_tensors(
+            [
+                ("freq_cis_real".to_string(), freq_cis_real),
+                ("freq_cis_imag".to_string(), freq_cis_imag),
+            ]
+            .into_iter()
+            .collect(),
+            candle::DType::F32,
+            &candle::Device::Cpu,
+        );
+        let cache = model::Cache::new(true, &config, fake_vb)?;
+        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
+        (model, config)
+    } else if is_safetensors {
         let config = Config::tiny();
         let tensors = candle::safetensors::load(config_path, &device)?;
         let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
-        (vb, config)
+        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
+        (model, config)
     } else {
         let mut file = std::fs::File::open(config_path)?;
         let config = Config::from_reader(&mut file)?;
         println!("{config:?}");
         let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
         let vb = weights.var_builder(&config, &device)?;
-        (vb, config)
+        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
+        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
+        (model, config)
     };
-    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
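A note on the fake_vb detour above: model::Cache::new reads the precomputed RoPE tables (freq_cis_real / freq_cis_imag) through a candle_nn::VarBuilder, but a GGUF file only yields quantized tensors, so the two tables are dequantized to f32 and repackaged into an in-memory VarBuilder. A self-contained sketch of that repackaging pattern, with stand-in shapes and zero-filled tensors instead of the real tables:

use std::collections::HashMap;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Stand-ins for the dequantized RoPE tables; the real shape is
    // (seq_len, head_size / 2).
    let real = Tensor::zeros((256, 24), DType::F32, &dev)?;
    let imag = Tensor::zeros((256, 24), DType::F32, &dev)?;
    let tensors: HashMap<String, Tensor> = [
        ("freq_cis_real".to_string(), real),
        ("freq_cis_imag".to_string(), imag),
    ]
    .into_iter()
    .collect();
    // An in-memory VarBuilder: get() now resolves names against the map,
    // which is all Cache::new needs from it.
    let vb = VarBuilder::from_tensors(tensors, DType::F32, &dev);
    let table = vb.get((256, 24), "freq_cis_real")?;
    println!("{:?}", table.shape());
    Ok(())
}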
@@ -273,7 +321,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let start_gen = std::time::Instant::now();
     for index in 0.. {
-        if tokens.len() >= model.config.seq_len {
+        if tokens.len() >= config.seq_len {
             break;
         }
         let context_size = if index > 0 { 1 } else { tokens.len() };
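The last hunk reads seq_len from the retained config because model is now the enum, which does not expose a public config field. The loop it guards implements incremental decoding: the first iteration feeds the whole prompt, every later one only the freshly sampled token, with the position offset advancing so the KV cache supplies the earlier context. A tensor-free toy replay of that schedule; everything here is illustrative and the + 1 stands in for sampling:

// Toy replay of the generation loop's scheduling, no model involved.
fn main() {
    let seq_len = 8; // stand-in for config.seq_len
    let mut tokens: Vec<u32> = vec![1, 2, 3]; // the prompt
    let mut pos = 0; // positions already held by the KV cache
    for index in 0.. {
        if tokens.len() >= seq_len {
            break; // same stop condition as in run_inference
        }
        // Whole prompt on step 0; afterwards a single token, since the
        // cache already covers every position below pos.
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len() - context_size..];
        pos += ctxt.len();
        let next = ctxt.last().unwrap() + 1; // stand-in for sampling
        tokens.push(next);
    }
    // Every token except the last appended one was fed to the "model".
    assert_eq!(pos, seq_len - 1);
}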