diff options
Diffstat (limited to 'candle-transformers/src/models/mobileclip.rs')
-rw-r--r-- | candle-transformers/src/models/mobileclip.rs | 89 |
1 file changed, 89 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/mobileclip.rs b/candle-transformers/src/models/mobileclip.rs new file mode 100644 index 00000000..4953d835 --- /dev/null +++ b/candle-transformers/src/models/mobileclip.rs @@ -0,0 +1,89 @@ +use super::fastvit; +use super::openclip::text_model; +use candle::{Result, Tensor, D}; +use candle_nn::{Func, VarBuilder}; + +#[derive(Clone, Debug)] +pub struct MobileClipModel { + text_model: text_model::OpenClipTextTransformer, + vision_model: Func<'static>, + text_projection: Tensor, + logit_scale: Tensor, +} + +#[derive(Clone, Debug)] +pub struct MobileClipConfig { + pub text_config: text_model::Config, + pub vision_config: fastvit::Config, + pub image_size: usize, +} + +impl MobileClipConfig { + pub fn s1() -> Self { + let text_config = text_model::Config::vit_base_patch32(); + let vision_config = fastvit::Config::mci1(); + + Self { + text_config, + vision_config, + image_size: 256, + } + } + pub fn s2() -> Self { + let text_config = text_model::Config::vit_base_patch32(); + let vision_config = fastvit::Config::mci2(); + + Self { + text_config, + vision_config, + image_size: 256, + } + } +} + +impl MobileClipModel { + pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> { + let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?; + let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?; + + let text_projection = vs.get( + (c.text_config.embed_dim, c.text_config.projection_dim), + "text.text_projection", + )?; + + let logit_scale = vs.get(&[], "logit_scale")?; + Ok(Self { + text_model, + vision_model, + text_projection, + logit_scale, + }) + } + + pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> { + input_ids + .apply(&self.text_model)? 
+ .matmul(&self.text_projection) + } + + pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> { + pixel_values.apply(&self.vision_model) + } + + pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> { + let image_features = self.get_image_features(pixel_values)?; + let text_features = self.get_text_features(input_ids)?; + let image_features_normalized = div_l2_norm(&image_features)?; + let text_features_normalized = div_l2_norm(&text_features)?; + let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?; + let logit_scale = self.logit_scale.exp()?; + let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?; + let logits_per_image = logits_per_text.t()?; + Ok((logits_per_text, logits_per_image)) + } +} + +pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + v.broadcast_div(&l2_norm) +} |