diff options
Diffstat (limited to 'candle-examples/examples/mamba/main.rs')
-rw-r--r-- | candle-examples/examples/mamba/main.rs | 299 |
1 files changed, 299 insertions, 0 deletions
diff --git a/candle-examples/examples/mamba/main.rs b/candle-examples/examples/mamba/main.rs new file mode 100644 index 00000000..4802f960 --- /dev/null +++ b/candle-examples/examples/mamba/main.rs @@ -0,0 +1,299 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::{Parser, ValueEnum}; + +use candle_transformers::models::mamba::{Config, Model, State}; + +use candle::{DType, Device, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + config: Config, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + config: Config, + tokenizer: Tokenizer, + seed: u64, + temp: Option<f64>, + top_p: Option<f64>, + repeat_penalty: f32, + repeat_last_n: usize, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + config, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let mut generated_tokens = 0usize; + let eos_token = match self.tokenizer.get_token("<|endoftext|>") { + Some(token) => token, + None => anyhow::bail!("cannot find the </s> token"), + }; + let mut state = State::new(1, &self.config, &self.device)?; + let mut next_logits = None; + for &t in tokens.iter() { + let input = Tensor::new(&[t], &self.device)?; + let logits = self.model.forward(&input, &mut state)?; + next_logits = Some(logits); + if let Some(t) = self.tokenizer.next_token(t)? { + print!("{t}") + } + } + std::io::stdout().flush()?; + + let start_gen = std::time::Instant::now(); + for _ in 0..sample_len { + let logits = match next_logits.as_ref() { + Some(logits) => logits, + None => anyhow::bail!("cannot work on an empty prompt"), + }; + let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + let input = Tensor::new(&[next_token], &self.device)?; + next_logits = Some(self.model.forward(&input, &mut state)?) + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)] +enum Which { + Mamba130m, + Mamba370m, + Mamba790m, + Mamba1_4b, + Mamba2_8b, + Mamba2_8bSlimPj, +} + +impl std::fmt::Display for Which { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl Which { + fn model_id(&self) -> &'static str { + match self { + Self::Mamba130m => "state-spaces/mamba-130m", + Self::Mamba370m => "state-spaces/mamba-370m", + Self::Mamba790m => "state-spaces/mamba-790m", + Self::Mamba1_4b => "state-spaces/mamba-1.4b", + Self::Mamba2_8b => "state-spaces/mamba-2.8b", + Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'", + } + } + + fn revision(&self) -> &'static str { + match self { + Self::Mamba130m + | Self::Mamba370m + | Self::Mamba790m + | Self::Mamba1_4b + | Self::Mamba2_8bSlimPj => "refs/pr/1", + Self::Mamba2_8b => "refs/pr/4", + } + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option<f64>, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 5000)] + sample_len: usize, + + #[arg(long, default_value = "mamba130m")] + which: Which, + + #[arg(long)] + model_id: Option<String>, + + #[arg(long)] + revision: Option<String>, + + #[arg(long)] + tokenizer_file: Option<String>, + + #[arg(long)] + weight_files: Option<String>, + + #[arg(long)] + config_file: Option<String>, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let repo = api.repo(Repo::with_revision( + args.model_id + .unwrap_or_else(|| args.which.model_id().to_string()), + RepoType::Model, + args.revision + .unwrap_or_else(|| args.which.revision().to_string()), + )); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => api + .model("EleutherAI/gpt-neox-20b".to_string()) + .get("tokenizer.json")?, + }; + let config_filename = match args.config_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("config.json")?, + }; + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::<Vec<_>>(), + None => { + vec![repo.get("model.safetensors")?] + } + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let start = std::time::Instant::now(); + let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let device = candle_examples::device(args.cpu)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let model = Model::new(&config, vb.pp("backbone"))?; + println!("loaded the model in {:?}", start.elapsed()); + + let mut pipeline = TextGeneration::new( + model, + config, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + &device, + ); + pipeline.run(&args.prompt, args.sample_len)?; + Ok(()) +} |