Diffstat (limited to 'candle-transformers/src/models/clip/vision_model.rs')
-rw-r--r--  candle-transformers/src/models/clip/vision_model.rs  24
1 file changed, 24 insertions, 0 deletions
diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs
index 88992434..e64cab16 100644
--- a/candle-transformers/src/models/clip/vision_model.rs
+++ b/candle-transformers/src/models/clip/vision_model.rs
@@ -46,6 +46,19 @@ impl ClipVisionConfig {
             patch_size: 32,
         }
     }
+    pub fn clip_vit_large_patch14_336() -> Self {
+        Self {
+            embed_dim: 1024,
+            activation: Activation::QuickGelu,
+            intermediate_size: 4096,
+            num_hidden_layers: 24,
+            num_attention_heads: 16,
+            projection_dim: 768,
+            num_channels: 3,
+            image_size: 336,
+            patch_size: 14,
+        }
+    }
 }
 
 // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
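The values in this preset match the openai/clip-vit-large-patch14-336 checkpoint that LLaVA uses as its vision tower. A minimal usage sketch of the new constructor follows; the import path and public field access are assumptions based on the surrounding file, not something this diff guarantees:

// Hypothetical example: import path and field visibility are assumed, not shown in the diff.
use candle_transformers::models::clip::vision_model::ClipVisionConfig;

fn main() {
    // Preset added by this commit: CLIP ViT-L/14 at 336x336 input resolution.
    let cfg = ClipVisionConfig::clip_vit_large_patch14_336();
    // 336 / 14 = 24 patches per side; 24 * 24 patches + 1 class token = 577 positions.
    let patches_per_side = cfg.image_size / cfg.patch_size;
    let seq_len = patches_per_side * patches_per_side + 1;
    println!("embed_dim = {}, sequence length = {}", cfg.embed_dim, seq_len);
}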
@@ -130,6 +143,17 @@ impl ClipVisionTransformer {
             pre_layer_norm,
         })
     }
+    // Required by LLaVA: returns the encoder's per-layer hidden states plus the normalized pooled output.
+    pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
+        let hidden_states = pixel_values
+            .apply(&self.embeddings)?
+            .apply(&self.pre_layer_norm)?;
+        let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
+        let encoder_outputs = result.last().unwrap();
+        let pooled_output = encoder_outputs.i((.., 0, ..))?;
+        result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
+        Ok(result)
+    }
 }
 
 impl Module for ClipVisionTransformer {
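Because the method appends the layer-normalized pooled output after the per-layer encoder states, callers that want intermediate features (as LLaVA does) have to split the returned vector. A hedged consumer sketch; the import paths, the helper name, and the choice of which layer to keep are illustrative assumptions rather than part of this commit:

use candle_core::{IndexOp, Result, Tensor};
use candle_transformers::models::clip::vision_model::ClipVisionTransformer;

// Illustrative helper: strip the pooled output that output_hidden_states appends last,
// then take one encoder layer's hidden states and drop the class token.
fn select_patch_features(tower: &ClipVisionTransformer, pixel_values: &Tensor) -> Result<Tensor> {
    let mut states = tower.output_hidden_states(pixel_values)?;
    // The last element is the layer-normalized pooled (class-token) output pushed by the new method.
    let _pooled = states.pop().expect("output_hidden_states returned an empty vec");
    // Keep the final encoder layer here; LLaVA variants often select an earlier layer,
    // which would simply be a different index into `states`.
    let layer = states.last().expect("no encoder hidden states");
    // Shape is (batch, 1 + num_patches, embed_dim); drop the class token at position 0.
    layer.i((.., 1.., ..))
}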