Diffstat (limited to 'candle-transformers/src/models/clip/vision_model.rs')
-rw-r--r--  candle-transformers/src/models/clip/vision_model.rs | 24
1 file changed, 24 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs
index 88992434..e64cab16 100644
--- a/candle-transformers/src/models/clip/vision_model.rs
+++ b/candle-transformers/src/models/clip/vision_model.rs
@@ -46,6 +46,19 @@ impl ClipVisionConfig {
             patch_size: 32,
         }
     }
+    pub fn clip_vit_large_patch14_336() -> Self {
+        Self {
+            embed_dim: 1024,
+            activation: Activation::QuickGelu,
+            intermediate_size: 4096,
+            num_hidden_layers: 24,
+            num_attention_heads: 16,
+            projection_dim: 768,
+            num_channels: 3,
+            image_size: 336,
+            patch_size: 14,
+        }
+    }
 }
 
 // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
@@ -130,6 +143,17 @@ impl ClipVisionTransformer {
             pre_layer_norm,
         })
     }
+    // required by LLaVA
+    pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
+        let hidden_states = pixel_values
+            .apply(&self.embeddings)?
+            .apply(&self.pre_layer_norm)?;
+        let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
+        let encoder_outputs = result.last().unwrap();
+        let pooled_output = encoder_outputs.i((.., 0, ..))?;
+        result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
+        Ok(result)
+    }
 }
 
 impl Module for ClipVisionTransformer {
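
For context, a minimal sketch of how the two additions could be used together for LLaVA-style image encoding. This is not part of the commit: the crate paths assume the candle workspace aliases (candle for candle-core), ClipVisionTransformer::new(vs, &config) is assumed to be the existing constructor in this file, the safetensors filename is hypothetical, and the choice of the second-to-last hidden state with the CLS token dropped mirrors LLaVA's usual feature selection rather than anything fixed by this diff.

use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical weight file holding the CLIP-ViT-L/14-336px vision tower; depending on
    // the checkpoint layout a prefix such as vb.pp("vision_model") may be needed.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &["clip-vit-large-patch14-336.safetensors"],
            DType::F32,
            &device,
        )?
    };
    let config = ClipVisionConfig::clip_vit_large_patch14_336();
    let model = ClipVisionTransformer::new(vb, &config)?;

    // Dummy batch of one 336x336 RGB image; real inputs come from the CLIP image processor.
    let pixel_values = Tensor::zeros((1, 3, 336, 336), DType::F32, &device)?;

    // Per-layer hidden states, with the layer-normed pooled output appended last.
    let hidden_states = model.output_hidden_states(&pixel_values)?;

    // Illustrative LLaVA-style selection: second-to-last entry, CLS token dropped,
    // leaving (batch, 576, 1024) patch features for the multimodal projector.
    let image_features = hidden_states[hidden_states.len() - 2].i((.., 1.., ..))?;
    println!("image features: {:?}", image_features.shape());
    Ok(())
}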