Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r-- | candle-examples/examples/phi/main.rs | 163
1 files changed, 163 insertions, 0 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
new file mode 100644
index 00000000..4b290cd8
--- /dev/null
+++ b/candle-examples/examples/phi/main.rs
@@ -0,0 +1,163 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::Parser;
+
+use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as Model};
+
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+    model: Model,
+    device: Device,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
+}
+
+impl TextGeneration {
+    fn new(
+        model: Model,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        device: &Device,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            tokenizer,
+            logits_processor,
+            device: device.clone(),
+        }
+    }
+
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
+        println!("starting the inference loop");
+        print!("{prompt}");
+        std::io::stdout().flush()?;
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        let mut new_tokens = vec![];
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input)?;
+            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            new_tokens.push(next_token);
+            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
+            print!("{token}");
+            std::io::stdout().flush()?;
+        }
+        let dt = start_gen.elapsed();
+        println!(
+            "{sample_len} tokens generated ({:.3} token/s)",
+            sample_len as f64 / dt.as_secs_f64(),
+        );
+        Ok(())
+    }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, default_value_t = 100)]
+    sample_len: usize,
+
+    #[arg(long, default_value = "microsoft/phi-1_5")]
+    model_id: String,
+
+    #[arg(long, default_value = "refs/pr/18")]
+    revision: String,
+
+    #[arg(long)]
+    weight_file: Option<String>,
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+
+    let start = std::time::Instant::now();
+    let api = Api::new()?;
+    let repo = api.repo(Repo::with_revision(
+        args.model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let tokenizer_filename = repo.get("tokenizer.json")?;
+    let filenames = match args.weight_file {
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
+        None => ["model.safetensors"]
+            .iter()
+            .map(|f| repo.get(f))
+            .collect::<std::result::Result<Vec<_>, _>>()?,
+    };
+    println!("retrieved the files in {:?}", start.elapsed());
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let weights = filenames
+        .iter()
+        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
+        .collect::<Result<Vec<_>>>()?;
+    let weights = weights
+        .iter()
+        .map(|f| Ok(f.deserialize()?))
+        .collect::<Result<Vec<_>>>()?;
+
+    let start = std::time::Instant::now();
+    let device = candle_examples::device(args.cpu)?;
+    let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
+    let config = Config::v1_5();
+    let model = Model::new(&config, vb)?;
+    println!("loaded the model in {:?}", start.elapsed());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        &device,
+    );
+    pipeline.run(&args.prompt, args.sample_len)?;
+    Ok(())
+}
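
For reference, a typical invocation of the new example would look like the sketch below. The binary name `phi` is an assumption based on the file's path under candle-examples/examples/, and the prompt text is only an illustration; the flags themselves map one-to-one onto the fields of the Args struct added above.

    cargo run --example phi --release -- --prompt "def print_prime(n):" --sample-len 150

With no --weight-file given, the tokenizer and model.safetensors weights are fetched from the microsoft/phi-1_5 repository on the Hugging Face Hub at revision refs/pr/18, and --cpu forces execution on CPU instead of the default device selected by candle_examples::device.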