summaryrefslogtreecommitdiff
path: root/candle-examples/examples/clip
diff options
context:
space:
mode:
authorTigran Zhampeissov <81493298+Tigranchick@users.noreply.github.com>2024-03-28 17:44:12 +0500
committerGitHub <noreply@github.com>2024-03-28 13:44:12 +0100
commitb0340d72ec9dd8f3bb1778e5a7d73111e67a4393 (patch)
tree1c15c1c5edae072d2c0e5d4f15c139ff9377210d /candle-examples/examples/clip
parentb3484e7a5e8d8c613e2a444c6f056142fc1e758d (diff)
downloadcandle-b0340d72ec9dd8f3bb1778e5a7d73111e67a4393.tar.gz
candle-b0340d72ec9dd8f3bb1778e5a7d73111e67a4393.tar.bz2
candle-b0340d72ec9dd8f3bb1778e5a7d73111e67a4393.zip
CLIP model implementation with example (#1950)
* CLIP model implementation with example * CLIP Implementation fixes, batch images * CLIP model remove images from git * CLIP model remove unnecessary use of batch_indices
Diffstat (limited to 'candle-examples/examples/clip')
-rw-r--r--candle-examples/examples/clip/README.md46
-rw-r--r--candle-examples/examples/clip/main.rs202
2 files changed, 248 insertions, 0 deletions
diff --git a/candle-examples/examples/clip/README.md b/candle-examples/examples/clip/README.md
new file mode 100644
index 00000000..f0ee3b2c
--- /dev/null
+++ b/candle-examples/examples/clip/README.md
@@ -0,0 +1,46 @@
+Contrastive Language-Image Pre-Training
+
+Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+pairs of images with related texts.
+
+https://github.com/openai/CLIP
+
+https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
+
+## Running on an example on cpu
+
+```
+$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
+
+
+Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
+
+INFO clip: Probability: 0.0000% Text: a cycling race
+INFO clip: Probability: 0.0000% Text: a photo of two cats
+INFO clip: Probability: 100.0000% Text: a robot holding a candle
+
+Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
+
+INFO clip: Probability: 99.9999% Text: a cycling race
+INFO clip: Probability: 0.0001% Text: a photo of two cats
+INFO clip: Probability: 0.0000% Text: a robot holding a candle
+```
+
+## Running on an example with metal feature (mac)
+
+```
+$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
+
+
+Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
+
+INFO clip: Probability: 0.0000% Text: a cycling race
+INFO clip: Probability: 0.0000% Text: a photo of two cats
+INFO clip: Probability: 100.0000% Text: a robot holding a candle
+
+Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
+
+INFO clip: Probability: 99.9999% Text: a cycling race
+INFO clip: Probability: 0.0001% Text: a photo of two cats
+INFO clip: Probability: 0.0000% Text: a robot holding a candle
+```
diff --git a/candle-examples/examples/clip/main.rs b/candle-examples/examples/clip/main.rs
new file mode 100644
index 00000000..f301d211
--- /dev/null
+++ b/candle-examples/examples/clip/main.rs
@@ -0,0 +1,202 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Error as E;
+use clap::Parser;
+
+use candle::{DType, Device, Tensor};
+use candle_nn::{ops::softmax, VarBuilder};
+use candle_transformers::models::clip;
+
+use tokenizers::Tokenizer;
+use tracing::info;
+
+#[derive(Parser)]
+struct Args {
+ #[arg(long)]
+ model: Option<String>,
+
+ #[arg(long)]
+ tokenizer: Option<String>,
+
+ #[arg(long, use_value_delimiter = true)]
+ images: Option<Vec<String>>,
+
+ #[arg(long)]
+ cpu: bool,
+
+ #[arg(long, use_value_delimiter = true)]
+ sequences: Option<Vec<String>>,
+}
+
+fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
+ let img = image::io::Reader::open(path)?.decode()?;
+ let (height, width) = (image_size, image_size);
+ let img = img.resize_to_fill(
+ width as u32,
+ height as u32,
+ image::imageops::FilterType::Triangle,
+ );
+
+ let img = img.to_rgb8();
+
+ let img = img.into_raw();
+ let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+ .permute((2, 0, 1))?
+ .to_dtype(DType::F32)?
+ .affine(2. / 255., -1.)?;
+ // .unsqueeze(0)?;
+ Ok(img)
+}
+
+fn load_images<T: AsRef<std::path::Path>>(
+ paths: &Vec<T>,
+ image_size: usize,
+) -> anyhow::Result<Tensor> {
+ let mut images = vec![];
+
+ for path in paths {
+ let tensor = load_image(path, image_size)?;
+ images.push(tensor);
+ }
+
+ let images = Tensor::stack(&images, 0)?;
+
+ Ok(images)
+}
+
+pub fn main() -> anyhow::Result<()> {
+ // std::env::set_var("RUST_BACKTRACE", "full");
+
+ let args = Args::parse();
+
+ tracing_subscriber::fmt::init();
+
+ let model_file = match args.model {
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+
+ let api = api.repo(hf_hub::Repo::with_revision(
+ "openai/clip-vit-base-patch32".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/15".to_string(),
+ ));
+
+ api.get("model.safetensors")?
+ }
+ Some(model) => model.into(),
+ };
+
+ let tokenizer = get_tokenizer(args.tokenizer)?;
+
+ let config = clip::ClipConfig::vit_base_patch32();
+
+ let device = candle_examples::device(args.cpu)?;
+
+ let vec_imgs = match args.images {
+ Some(imgs) => imgs,
+ None => vec![
+ "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
+ "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
+ ],
+ };
+
+ // let image = load_image(args.image, config.image_size)?.to_device(&device)?;
+ let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
+
+ let vb =
+ unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
+
+ let model = clip::ClipModel::new(vb, &config)?;
+
+ let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
+
+ let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
+
+ let softmax_image = softmax(&logits_per_image, 1)?;
+
+ let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
+
+ info!("softmax_image_vec: {:?}", softmax_image_vec);
+
+ let probability_vec = softmax_image_vec
+ .iter()
+ .map(|v| v * 100.0)
+ .collect::<Vec<f32>>();
+
+ let probability_per_image = probability_vec.len() / vec_imgs.len();
+
+ for (i, img) in vec_imgs.iter().enumerate() {
+ let start = i * probability_per_image;
+ let end = start + probability_per_image;
+ let prob = &probability_vec[start..end];
+ info!("\n\nResults for image: {}\n", img);
+
+ for (i, p) in prob.iter().enumerate() {
+ info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
+ }
+ }
+
+ Ok(())
+}
+
+pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
+ let tokenizer = match tokenizer {
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.repo(hf_hub::Repo::with_revision(
+ "openai/clip-vit-base-patch32".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/15".to_string(),
+ ));
+ api.get("tokenizer.json")?
+ }
+ Some(file) => file.into(),
+ };
+
+ Tokenizer::from_file(tokenizer).map_err(E::msg)
+}
+
+pub fn tokenize_sequences(
+ sequences: Option<Vec<String>>,
+ tokenizer: &Tokenizer,
+ device: &Device,
+) -> anyhow::Result<(Tensor, Vec<String>)> {
+ let pad_id = *tokenizer
+ .get_vocab(true)
+ .get("<|endoftext|>")
+ .ok_or(E::msg("No pad token"))?;
+
+ let vec_seq = match sequences {
+ Some(seq) => seq,
+ None => vec![
+ "a cycling race".to_string(),
+ "a photo of two cats".to_string(),
+ "a robot holding a candle".to_string(),
+ ],
+ };
+
+ let mut tokens = vec![];
+
+ for seq in vec_seq.clone() {
+ let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
+ tokens.push(encoding.get_ids().to_vec());
+ }
+
+ let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
+
+ // Pad the sequences to have the same length
+ for token_vec in tokens.iter_mut() {
+ let len_diff = max_len - token_vec.len();
+ if len_diff > 0 {
+ token_vec.extend(vec![pad_id; len_diff]);
+ }
+ }
+
+ let input_ids = Tensor::new(tokens, device)?;
+
+ Ok((input_ids, vec_seq))
+}