diff options
Diffstat (limited to 'candle-examples/examples/moondream/main.rs')
-rw-r--r-- | candle-examples/examples/moondream/main.rs | 93 |
1 files changed, 77 insertions, 16 deletions
diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs index 3e0f6d57..008346f0 100644 --- a/candle-examples/examples/moondream/main.rs +++ b/candle-examples/examples/moondream/main.rs @@ -9,11 +9,19 @@ use clap::Parser; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; -use candle_transformers::{generation::LogitsProcessor, models::moondream}; +use candle_transformers::{ + generation::LogitsProcessor, + models::{moondream, quantized_moondream}, +}; use tokenizers::Tokenizer; +enum Model { + Moondream(moondream::Model), + Quantized(quantized_moondream::Model), +} + struct TextGeneration { - model: moondream::Model, + model: Model, device: Device, tokenizer: Tokenizer, logits_processor: LogitsProcessor, @@ -25,7 +33,7 @@ struct TextGeneration { impl TextGeneration { #[allow(clippy::too_many_arguments)] fn new( - model: moondream::Model, + model: Model, tokenizer: Tokenizer, seed: u64, temp: Option<f64>, @@ -64,6 +72,14 @@ impl TextGeneration { let mut tokens = tokens.get_ids().to_vec(); let mut generated_tokens = 0usize; + // Moondream tokenizer bos_token is "<|endoftext|>" + // https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json + let bos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { + Some(token) => *token, + None => anyhow::bail!("cannot find the BOS token"), + }; + // eos_token is "END" + // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L100 let eos_token = match self.tokenizer.get_vocab(true).get("END") { Some(token) => *token, None => anyhow::bail!("cannot find the EOS token"), @@ -75,11 +91,24 @@ impl TextGeneration { let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = if index > 0 { - self.model.text_model.forward(&input)? + match self.model { + Model::Moondream(ref mut model) => model.text_model.forward(&input)?, + Model::Quantized(ref mut model) => model.text_model.forward(&input)?, + } } else { - self.model - .text_model - .forward_with_img(&input, image_embeds)? + let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?; + match self.model { + Model::Moondream(ref mut model) => { + model + .text_model + .forward_with_img(&bos_token, &input, image_embeds)? + } + Model::Quantized(ref mut model) => { + model + .text_model + .forward_with_img(&bos_token, &input, image_embeds)? + } + } }; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; let logits = if self.repeat_penalty == 1. { @@ -142,7 +171,7 @@ struct Args { top_p: Option<f64>, /// The seed to use when generating random samples. - #[arg(long, default_value_t = 299792458)] + #[arg(long, default_value_t = 0)] seed: u64, #[arg(long, default_value_t = 5000)] @@ -156,13 +185,16 @@ struct Args { #[arg(long, default_value_t = 64)] repeat_last_n: usize, - #[arg(long, default_value = "vikhyatk/moondream2")] - model_id: String, + #[arg(long)] + model_id: Option<String>, #[arg(long, default_value = "main")] revision: String, #[arg(long)] + quantized: bool, + + #[arg(long)] model_file: Option<String>, #[arg(long)] @@ -216,14 +248,30 @@ async fn main() -> anyhow::Result<()> { let start = std::time::Instant::now(); let api = hf_hub::api::tokio::Api::new()?; + let model_id = match args.model_id { + Some(model_id) => model_id.to_string(), + None => { + if args.quantized { + "santiagomed/candle-moondream".to_string() + } else { + "vikhyatk/moondream2".to_string() + } + } + }; let repo = api.repo(hf_hub::Repo::with_revision( - args.model_id, + model_id, hf_hub::RepoType::Model, args.revision, )); let model_file = match args.model_file { Some(m) => m.into(), - None => repo.get("model.safetensors").await?, + None => { + if args.quantized { + repo.get("model-q4_0.gguf").await? + } else { + repo.get("model.safetensors").await? + } + } }; let tokenizer = match args.tokenizer_file { Some(m) => m.into(), @@ -234,22 +282,35 @@ async fn main() -> anyhow::Result<()> { let start = std::time::Instant::now(); let device = candle_examples::device(args.cpu)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; let config = moondream::Config::v2(); - let model = moondream::Model::new(&config, vb)?; + let model = if args.quantized { + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( + &model_file, + &device, + )?; + let model = quantized_moondream::Model::new(&config, vb)?; + Model::Quantized(model) + } else { + let vb = + unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = moondream::Model::new(&config, vb)?; + Model::Moondream(model) + }; println!("loaded the model in {:?}", start.elapsed()); let start = std::time::Instant::now(); let image = load_image(args.image)?.to_device(&device)?; let image_embeds = image.unsqueeze(0)?; - let image_embeds = image_embeds.apply(model.vision_encoder())?; + let image_embeds = match model { + Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?, + Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?, + }; println!( "loaded and encoded the image {image:?} in {:?}", start.elapsed() ); let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt); - let mut pipeline = TextGeneration::new( model, tokenizer, |