diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-28 23:48:00 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-28 23:48:00 +0200 |
commit | 261ed65f36c3f66ab33335850797ff473cbe4dd0 (patch) | |
tree | 678383e6025639996ebe8444b207a285c1a5445d /candle-examples/examples/clip | |
parent | 62525e83526465b2d7969c8d3d0213b491b7ccc8 (diff) | |
download | candle-261ed65f36c3f66ab33335850797ff473cbe4dd0.tar.gz candle-261ed65f36c3f66ab33335850797ff473cbe4dd0.tar.bz2 candle-261ed65f36c3f66ab33335850797ff473cbe4dd0.zip |
Add the SigLIP model. (#2515)
* Add the SigLIP model.
* Add more to the forward pass of the vision model.
* Complete the forward pass.
* Add the siglip example.
* Fix.
* Another fix.
* Get everything in place.
* Add a readme.
Diffstat (limited to 'candle-examples/examples/clip')
-rw-r--r-- | candle-examples/examples/clip/main.rs | 44 |
1 file changed, 3 insertions, 41 deletions
diff --git a/candle-examples/examples/clip/main.rs b/candle-examples/examples/clip/main.rs index d057663d..273edb6a 100644 --- a/candle-examples/examples/clip/main.rs +++ b/candle-examples/examples/clip/main.rs @@ -12,7 +12,6 @@ use candle_nn::{ops::softmax, VarBuilder}; use candle_transformers::models::clip; use tokenizers::Tokenizer; -use tracing::info; #[derive(Parser)] struct Args { @@ -40,15 +39,12 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow:: height as u32, image::imageops::FilterType::Triangle, ); - let img = img.to_rgb8(); - let img = img.into_raw(); let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)? .permute((2, 0, 1))? .to_dtype(DType::F32)? .affine(2. / 255., -1.)?; - // .unsqueeze(0)?; Ok(img) } @@ -57,24 +53,16 @@ fn load_images<T: AsRef<std::path::Path>>( image_size: usize, ) -> anyhow::Result<Tensor> { let mut images = vec![]; - for path in paths { let tensor = load_image(path, image_size)?; images.push(tensor); } - let images = Tensor::stack(&images, 0)?; - Ok(images) } pub fn main() -> anyhow::Result<()> { - // std::env::set_var("RUST_BACKTRACE", "full"); - let args = Args::parse(); - - tracing_subscriber::fmt::init(); - let model_file = match args.model { None => { let api = hf_hub::api::sync::Api::new()?; @@ -89,13 +77,9 @@ pub fn main() -> anyhow::Result<()> { } Some(model) => model.into(), }; - let tokenizer = get_tokenizer(args.tokenizer)?; - let config = clip::ClipConfig::vit_base_patch32(); - let device = candle_examples::device(args.cpu)?; - let vec_imgs = match args.images { Some(imgs) => imgs, None => vec![ @@ -103,43 +87,29 @@ pub fn main() -> anyhow::Result<()> { "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(), ], }; - - // let image = load_image(args.image, config.image_size)?.to_device(&device)?; let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, 
&device)? }; - let model = clip::ClipModel::new(vb, &config)?; - let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?; - let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?; - let softmax_image = softmax(&logits_per_image, 1)?; - let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?; - - info!("softmax_image_vec: {:?}", softmax_image_vec); - + println!("softmax_image_vec: {:?}", softmax_image_vec); let probability_vec = softmax_image_vec .iter() .map(|v| v * 100.0) .collect::<Vec<f32>>(); - let probability_per_image = probability_vec.len() / vec_imgs.len(); - for (i, img) in vec_imgs.iter().enumerate() { let start = i * probability_per_image; let end = start + probability_per_image; let prob = &probability_vec[start..end]; - info!("\n\nResults for image: {}\n", img); - + println!("\n\nResults for image: {}\n", img); for (i, p) in prob.iter().enumerate() { - info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]); + println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]); } } - Ok(()) } @@ -156,7 +126,6 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> { } Some(file) => file.into(), }; - Tokenizer::from_file(tokenizer).map_err(E::msg) } @@ -169,7 +138,6 @@ pub fn tokenize_sequences( .get_vocab(true) .get("<|endoftext|>") .ok_or(E::msg("No pad token"))?; - let vec_seq = match sequences { Some(seq) => seq, None => vec![ @@ -178,16 +146,12 @@ pub fn tokenize_sequences( "a robot holding a candle".to_string(), ], }; - let mut tokens = vec![]; - for seq in vec_seq.clone() { let encoding = tokenizer.encode(seq, true).map_err(E::msg)?; tokens.push(encoding.get_ids().to_vec()); } - let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0); - // Pad the sequences to have the same length for token_vec in tokens.iter_mut() { let len_diff = max_len - token_vec.len(); @@ -195,8 +159,6 @@ pub fn tokenize_sequences( token_vec.extend(vec![pad_id; len_diff]); } } - let 
input_ids = Tensor::new(tokens, device)?; - Ok((input_ids, vec_seq)) } |