// An implementation of different Granite models: https://www.ibm.com/granite

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use candle_transformers::models::granite as model;
use model::{Granite, GraniteConfig};

use std::time::Instant;

const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "How Fault Tolerant Quantum Computers will help humanity?";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum GraniteModel {
    Granite7bInstruct,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 10000)]
    sample_len: usize,

    /// Disable the key-value cache.
    #[arg(long)]
    no_kv_cache: bool,

    /// The initial prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// Use a different dtype than f16.
    #[arg(long)]
    dtype: Option<String>,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long, default_value = "granite7b-instruct")]
    model_type: GraniteModel,

    #[arg(long)]
    use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 128)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tokenizers::Tokenizer;
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let device = candle_examples::device(args.cpu)?;
    // Default to f16 weights unless another dtype is explicitly requested.
    let dtype = match args.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
    // Resolve the model repo on the Hugging Face Hub and fetch the tokenizer, config and weights.
    let (granite, tokenizer_filename, mut cache, config) = {
        let api = Api::new()?;
        let model_id = args.model_id.unwrap_or_else(|| match args.model_type {
            GraniteModel::Granite7bInstruct => "ibm-granite/granite-7b-instruct".to_string(),
        });
        println!("loading the model weights from {model_id}");
        let revision = args.revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: GraniteConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(args.use_flash_attn);

        let filenames = match args.model_type {
            GraniteModel::Granite7bInstruct => {
                candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
            }
        };
        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        (
            Granite::load(vb, &config)?,
            tokenizer_filename,
            cache,
            config,
        )
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    // Prefer the EOS token id(s) from the model config, falling back to the tokenizer's.
    let eos_token_id = config.eos_token_id.or_else(|| {
        tokenizer
            .token_to_id(EOS_TOKEN)
            .map(model::GraniteEosToks::Single)
    });

    let default_prompt = match args.model_type {
        GraniteModel::Granite7bInstruct => DEFAULT_PROMPT,
    };
    let prompt = args.prompt.as_ref().map_or(default_prompt, |p| p.as_str());

    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

    println!("Starting the inference loop:");
    print!("{prompt}");

    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let mut start_gen = std::time::Instant::now();
    let mut index_pos = 0;
    let mut token_generated = 0;
    let use_cache_kv = cache.use_kv_cache;

    // Generation loop: the timer is restarted at the second token so that prompt processing
    // is excluded from the tokens/s figure reported at the end.
    (0..args.sample_len)
        .inspect(|index| {
            if *index == 1 {
                start_gen = Instant::now();
            }
        })
        .try_for_each(|index| -> Result<()> {
            // With the KV cache enabled, only the last token has to be fed after the first step.
            let (context_size, context_index) = if use_cache_kv && index > 0 {
                (1, index_pos)
            } else {
                (tokens.len(), 0)
            };
            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
            let logits = granite
                .forward(&input, context_index, &mut cache)?
                .squeeze(0)?;
            let logits = if args.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(args.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    args.repeat_penalty,
                    &tokens[start_at..],
                )?
            };
            index_pos += ctxt.len();

            let next_token = logits_processor.sample(&logits)?;
            token_generated += 1;
            tokens.push(next_token);

            // Stop the iteration by returning an error once an EOS token is sampled.
            if let Some(model::GraniteEosToks::Single(eos_tok_id)) = eos_token_id {
                if next_token == eos_tok_id {
                    return Err(E::msg("EOS token found"));
                }
            } else if let Some(model::GraniteEosToks::Multiple(ref eos_ids)) = eos_token_id {
                if eos_ids.contains(&next_token) {
                    return Err(E::msg("EOS token found"));
                }
            }

            if let Some(t) = tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
            Ok(())
        })
        .unwrap_or(());

    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }

    let dt = start_gen.elapsed();
    println!(
        "\n\n{} tokens generated ({} token/s)\n",
        token_generated,
        (token_generated - 1) as f64 / dt.as_secs_f64(),
    );
    Ok(())
}
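
// Example invocation, as a minimal sketch: the `granite` example name and release profile are
// assumptions about the usual candle-examples layout and are not taken from this file; the CLI
// flags themselves come from the `Args` definition above.
//
//   cargo run --example granite --release -- \
//       --prompt "How Fault Tolerant Quantum Computers will help humanity?" \
//       --temperature 0.8 --sample-len 512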