summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/siglip.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/siglip.rs')
-rw-r--r--candle-transformers/src/models/siglip.rs54
1 files changed, 54 insertions, 0 deletions
diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs
index a3280a86..63b6635d 100644
--- a/candle-transformers/src/models/siglip.rs
+++ b/candle-transformers/src/models/siglip.rs
@@ -83,6 +83,60 @@ impl TransformerConfig for VisionConfig {
}
}
+impl VisionConfig {
+ pub fn paligemma_3b_224() -> Self {
+ Self {
+ // https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
+ patch_size: 14,
+ num_attention_heads: 16,
+ num_hidden_layers: 27,
+ hidden_size: 1152,
+ intermediate_size: 4304,
+ image_size: 224, // num_image_tokens: (224 / 14)^2 = 256
+ // Default values.
+ num_channels: 3,
+ hidden_act: candle_nn::Activation::GeluPytorchTanh,
+ layer_norm_eps: 1e-6,
+ }
+ }
+
+ pub fn paligemma_3b_448() -> Self {
+ Self {
+ // https://huggingface.co/google/paligemma-3b-pt-448/blob/main/config.json
+ patch_size: 14,
+ num_attention_heads: 16,
+ num_hidden_layers: 27,
+ hidden_size: 1152,
+ intermediate_size: 4304,
+ image_size: 448, // num_image_tokens: (448 / 14)^2 = 1024
+ // Default values.
+ num_channels: 3,
+ hidden_act: candle_nn::Activation::GeluPytorchTanh,
+ layer_norm_eps: 1e-6,
+ }
+ }
+
+ pub fn paligemma_3b_896() -> Self {
+ Self {
+ // https://huggingface.co/google/paligemma-3b-pt-896/blob/main/config.json
+ patch_size: 14,
+ num_attention_heads: 16,
+ num_hidden_layers: 27,
+ hidden_size: 1152,
+ intermediate_size: 4304,
+ image_size: 896, // num_image_tokens: (896 / 14)^2 = 4096
+ // Default values.
+ num_channels: 3,
+ hidden_act: candle_nn::Activation::GeluPytorchTanh,
+ layer_norm_eps: 1e-6,
+ }
+ }
+
+ pub fn num_patches(&self) -> usize {
+ (self.image_size / self.patch_size).pow(2)
+ }
+}
+
// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {