summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/siglip.rs
blob: 63b6635dc119c5dd64bf0ac64a6d49daa5ca0762 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
use crate::models::clip::div_l2_norm;
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
/// Hyper-parameters of the SigLIP text tower.
///
/// Deserialized with serde, so every field name below is also the expected JSON key.
#[derive(serde::Deserialize, Clone, Debug)]
pub struct TextConfig {
    /// Number of rows in the token embedding table.
    pub vocab_size: usize,
    /// Width of the hidden states; also the embedding and attention width.
    pub hidden_size: usize,
    /// Width of the inner projection of every MLP block.
    pub intermediate_size: usize,
    /// Number of stacked encoder layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads; each head gets `hidden_size / num_attention_heads` channels.
    pub num_attention_heads: usize,
    /// Size of the position embedding table, i.e. the longest token sequence that can be embedded.
    pub max_position_embeddings: usize,
    /// Activation applied inside every MLP block.
    pub hidden_act: candle_nn::Activation,
    /// Epsilon used by the encoder layer norms and the final layer norm.
    pub layer_norm_eps: f64,
    /// Id of the padding token.
    pub pad_token_id: u32,
    /// Id of the beginning-of-sequence token.
    pub bos_token_id: u32,
    /// Id of the end-of-sequence token.
    pub eos_token_id: u32,
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
/// Hyper-parameters of the SigLIP vision tower.
///
/// Deserialized with serde, so every field name below is also the expected JSON key.
#[derive(serde::Deserialize, Clone, Debug)]
pub struct VisionConfig {
    /// Width of the hidden states (and of each patch embedding).
    pub hidden_size: usize,
    /// Width of the inner projection of every MLP block.
    pub intermediate_size: usize,
    /// Number of stacked encoder layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads; each head gets `hidden_size / num_attention_heads` channels.
    pub num_attention_heads: usize,
    /// Number of input channels consumed by the patch-embedding convolution.
    pub num_channels: usize,
    /// Side length of the input images; the patch grid is assumed square (see `num_patches`).
    pub image_size: usize,
    /// Side length of one (square) patch; used as both kernel size and stride of the patch conv.
    pub patch_size: usize,
    /// Activation applied inside every MLP block.
    pub hidden_act: candle_nn::Activation,
    /// Epsilon used by every layer norm of the vision tower.
    pub layer_norm_eps: f64,
}

/// Hyper-parameters shared by the text and the vision tower.
///
/// The generic building blocks (`Attention`, `Mlp`, `EncoderLayer`, `Encoder`) are written
/// against this trait so they can be constructed from either config type.
trait TransformerConfig {
    /// Width of the hidden states.
    fn hidden_size(&self) -> usize;
    /// Number of attention heads per layer.
    fn num_attention_heads(&self) -> usize;
    /// Width of the MLP inner projection.
    fn intermediate_size(&self) -> usize;
    /// Activation applied inside the MLP.
    fn hidden_act(&self) -> candle_nn::Activation;
    /// Number of stacked encoder layers.
    fn num_hidden_layers(&self) -> usize;
    /// Epsilon of the layer norms.
    fn layer_norm_eps(&self) -> f64;
}

impl TransformerConfig for TextConfig {
    // Each accessor simply forwards the identically named config field.
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }

    fn intermediate_size(&self) -> usize {
        self.intermediate_size
    }

    fn hidden_act(&self) -> candle_nn::Activation {
        self.hidden_act
    }

    fn num_hidden_layers(&self) -> usize {
        self.num_hidden_layers
    }

    fn layer_norm_eps(&self) -> f64 {
        self.layer_norm_eps
    }
}

impl TransformerConfig for VisionConfig {
    // Each accessor simply forwards the identically named config field.
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }

    fn intermediate_size(&self) -> usize {
        self.intermediate_size
    }

    fn hidden_act(&self) -> candle_nn::Activation {
        self.hidden_act
    }

    fn num_hidden_layers(&self) -> usize {
        self.num_hidden_layers
    }

    fn layer_norm_eps(&self) -> f64 {
        self.layer_norm_eps
    }
}

impl VisionConfig {
    /// Vision-tower settings shared by every PaliGemma-3B checkpoint.
    ///
    /// The published variants differ only in input resolution.
    fn paligemma_3b_with_image_size(image_size: usize) -> Self {
        Self {
            patch_size: 14,
            num_attention_heads: 16,
            num_hidden_layers: 27,
            hidden_size: 1152,
            intermediate_size: 4304,
            image_size,
            // Default values.
            num_channels: 3,
            hidden_act: candle_nn::Activation::GeluPytorchTanh,
            layer_norm_eps: 1e-6,
        }
    }

    /// Vision config of PaliGemma-3B at 224x224 input.
    ///
    /// Source: https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
    /// num_image_tokens: (224 / 14)^2 = 256
    pub fn paligemma_3b_224() -> Self {
        Self::paligemma_3b_with_image_size(224)
    }

    /// Vision config of PaliGemma-3B at 448x448 input.
    ///
    /// Source: https://huggingface.co/google/paligemma-3b-pt-448/blob/main/config.json
    /// num_image_tokens: (448 / 14)^2 = 1024
    pub fn paligemma_3b_448() -> Self {
        Self::paligemma_3b_with_image_size(448)
    }

    /// Vision config of PaliGemma-3B at 896x896 input.
    ///
    /// Source: https://huggingface.co/google/paligemma-3b-pt-896/blob/main/config.json
    /// num_image_tokens: (896 / 14)^2 = 4096
    pub fn paligemma_3b_896() -> Self {
        Self::paligemma_3b_with_image_size(896)
    }

    /// Number of non-overlapping `patch_size` x `patch_size` patches per image.
    ///
    /// Integer division drops any partial patch at the border, matching an unpadded
    /// convolution whose stride equals its kernel size.
    pub fn num_patches(&self) -> usize {
        (self.image_size / self.patch_size).pow(2)
    }
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L228
/// Complete SigLIP configuration: one section per tower.
#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {
    /// Text tower settings, deserialized from the `text_config` key.
    pub text_config: TextConfig,
    /// Vision tower settings, deserialized from the `vision_config` key.
    pub vision_config: VisionConfig,
}

impl Config {
    /// Configuration of `google/siglip-base-patch16-224`.
    ///
    /// Values under "Default values." are not listed in
    /// https://huggingface.co/google/siglip-base-patch16-224/blob/main/config.json
    /// and are filled in with the library defaults instead.
    pub fn base_patch16_224() -> Self {
        Self {
            text_config: TextConfig {
                hidden_size: 768,
                intermediate_size: 3072,
                num_attention_heads: 12,
                vocab_size: 32000,
                // Default values.
                pad_token_id: 1,
                bos_token_id: 49406,
                eos_token_id: 49407,
                layer_norm_eps: 1e-6,
                hidden_act: candle_nn::Activation::GeluPytorchTanh,
                max_position_embeddings: 64,
                num_hidden_layers: 12,
            },
            vision_config: VisionConfig {
                patch_size: 16,
                // Default values.
                hidden_size: 768,
                intermediate_size: 3072,
                num_hidden_layers: 12,
                num_attention_heads: 12,
                num_channels: 3,
                image_size: 224,
                hidden_act: candle_nn::Activation::GeluPytorchTanh,
                layer_norm_eps: 1e-6,
            },
        }
    }
}

/// Multi-head attention whose q/k/v projections are stored fused.
///
/// The parameters live in a single `in_proj_weight` / `in_proj_bias` pair.
/// NOTE(review): presumably this is the PyTorch `nn.MultiheadAttention` layout the pooling head
/// was exported with. Queries may come from a different sequence than keys/values.
#[derive(Clone, Debug)]
struct MultiheadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
}

impl MultiheadAttention {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        // The fused tensors stack the q, k and v parameters along dim 0, in that order.
        let weights = vb.get((3 * hidden, hidden), "in_proj_weight")?.chunk(3, 0)?;
        let biases = vb.get(3 * hidden, "in_proj_bias")?.chunk(3, 0)?;
        let split_proj =
            |idx: usize| Linear::new(weights[idx].clone(), Some(biases[idx].clone()));
        Ok(Self {
            q_proj: split_proj(0),
            k_proj: split_proj(1),
            v_proj: split_proj(2),
            out_proj: linear(hidden, hidden, vb.pp("out_proj"))?,
            num_heads: cfg.num_attention_heads,
        })
    }

    /// `(batch, tokens, channels)` -> contiguous `(batch, heads, tokens, channels / heads)`.
    fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (batch, tokens, channels) = x.dims3()?;
        let per_head = channels / self.num_heads;
        let split = x.reshape((batch, tokens, self.num_heads, per_head))?;
        split.transpose(1, 2)?.contiguous()
    }

    /// Inverse of `separate_heads`: `(batch, heads, tokens, c)` -> `(batch, tokens, heads * c)`.
    fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
        let (batch, heads, tokens, per_head) = x.dims4()?;
        let swapped = x.transpose(1, 2)?;
        swapped.reshape((batch, tokens, heads * per_head))
    }

    /// Scaled dot-product attention of `q` over (`k`, `v`).
    ///
    /// Inputs are `(batch, tokens, hidden)`; the output keeps `q`'s token count.
    fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let heads_of = |proj: &Linear, x: &Tensor| -> Result<Tensor> {
            let projected = proj.forward(&x.contiguous()?)?;
            self.separate_heads(&projected)
        };
        let q = heads_of(&self.q_proj, q)?;
        let k = heads_of(&self.k_proj, k)?;
        let v = heads_of(&self.v_proj, v)?;

        let per_head = q.dim(D::Minus1)?;
        let scores = (q.matmul(&k.t()?)? / (per_head as f64).sqrt())?;
        let weights = candle_nn::ops::softmax_last_dim(&scores)?;
        let mixed = weights.matmul(&v)?;
        self.out_proj.forward(&self.recombine_heads(&mixed)?)
    }
}

/// Attention-pooling head of the vision tower.
///
/// A learned probe token attends over all patch tokens, then passes through a residual
/// layer-norm + MLP branch. Yields one `hidden_size` vector per image.
#[derive(Debug, Clone)]
struct MultiheadAttentionPoolingHead {
    /// Learned query of shape `(1, 1, hidden_size)`, repeated across the batch.
    probe: Tensor,
    attention: MultiheadAttention,
    layernorm: LayerNorm,
    mlp: Mlp,
}

impl MultiheadAttentionPoolingHead {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            mlp: Mlp::new(cfg, vb.pp("mlp"))?,
            layernorm: layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?,
            probe: vb.get((1, 1, cfg.hidden_size), "probe")?,
            attention: MultiheadAttention::new(cfg, vb.pp("attention"))?,
        })
    }
}

impl Module for MultiheadAttentionPoolingHead {
    /// `(batch, tokens, hidden)` -> `(batch, hidden)`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let queries = self.probe.repeat((xs.dim(0)?, 1, 1))?;
        let attended = self.attention.forward(&queries, xs, xs)?;
        let refined = self.mlp.forward(&self.layernorm.forward(&attended)?)?;
        let pooled = (refined + &attended)?;
        // Only the single probe position exists; drop that axis.
        pooled.i((.., 0))
    }
}

/// Multi-head self-attention with separate q/k/v/out projections.
///
/// Used by the encoder layers of both towers.
#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
    head_dim: usize,
    /// `head_dim^-0.5`, multiplied into the raw attention scores.
    scale: f64,
}

impl Attention {
    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
        let dim = cfg.hidden_size();
        let square_linear = |name: &str| linear(dim, dim, vb.pp(name));
        let q_proj = square_linear("q_proj")?;
        let k_proj = square_linear("k_proj")?;
        let v_proj = square_linear("v_proj")?;
        let out_proj = square_linear("out_proj")?;
        let num_heads = cfg.num_attention_heads();
        let head_dim = dim / num_heads;
        let scale = (head_dim as f64).powf(-0.5);
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
            head_dim,
            scale,
        })
    }

    /// Self-attention over `xs` of shape `(batch, seq, hidden)`.
    ///
    /// When given, `attention_mask` is broadcast-added to the `(batch, heads, seq, seq)` scores
    /// before the softmax.
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (batch_size, q_len, _) = xs.dims3()?;
        let split_heads = |proj: &Linear| -> Result<Tensor> {
            xs.apply(proj)?
                .reshape((batch_size, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?
                .contiguous()
        };
        let query_states = split_heads(&self.q_proj)?;
        let key_states = split_heads(&self.k_proj)?;
        let value_states = split_heads(&self.v_proj)?;

        let scores = (query_states.matmul(&key_states.t()?)? * self.scale)?;
        let scores = if let Some(mask) = attention_mask {
            scores.broadcast_add(mask)?
        } else {
            scores
        };
        // The reference implementation upcasts to f32 here; candle_nn's softmax is relied on
        // to be accurate enough without it.
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        probs
            .matmul(&value_states)?
            .transpose(1, 2)?
            .reshape((batch_size, q_len, ()))?
            .apply(&self.out_proj)
    }
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L599
/// Two-layer feed-forward block: `fc2(act(fc1(x)))`.
#[derive(Debug, Clone)]
struct Mlp {
    fc1: Linear,
    fc2: Linear,
    activation_fn: candle_nn::Activation,
}

impl Mlp {
    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
        let (hidden, inner) = (cfg.hidden_size(), cfg.intermediate_size());
        Ok(Self {
            fc1: candle_nn::linear(hidden, inner, vb.pp("fc1"))?,
            fc2: candle_nn::linear(inner, hidden, vb.pp("fc2"))?,
            activation_fn: cfg.hidden_act(),
        })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
        let expanded = self.fc1.forward(xs)?;
        let activated = self.activation_fn.forward(&expanded)?;
        self.fc2.forward(&activated)
    }
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/modeling_siglip.py#L614
/// Pre-norm transformer layer: `x + attn(ln1(x))`, then `+ mlp(ln2(.))`.
#[derive(Debug, Clone)]
struct EncoderLayer {
    self_attn: Attention,
    layer_norm1: LayerNorm,
    mlp: Mlp,
    layer_norm2: LayerNorm,
}

impl EncoderLayer {
    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
        let norm = |name: &str| layer_norm(cfg.hidden_size(), cfg.layer_norm_eps(), vb.pp(name));
        Ok(Self {
            self_attn: Attention::new(cfg, vb.pp("self_attn"))?,
            layer_norm1: norm("layer_norm1")?,
            mlp: Mlp::new(cfg, vb.pp("mlp"))?,
            layer_norm2: norm("layer_norm2")?,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        // Attention sub-block with residual connection.
        let attn_out = self
            .self_attn
            .forward(&self.layer_norm1.forward(xs)?, attention_mask)?;
        let hidden = (xs + attn_out)?;
        // Feed-forward sub-block with residual connection.
        let mlp_out = self.mlp.forward(&self.layer_norm2.forward(&hidden)?)?;
        mlp_out + &hidden
    }
}

/// A stack of `num_hidden_layers` encoder layers applied in order.
#[derive(Debug, Clone)]
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
        let layers_vb = vb.pp("layers");
        let layers = (0..cfg.num_hidden_layers())
            .map(|idx| EncoderLayer::new(cfg, layers_vb.pp(idx)))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }

    /// Threads `xs` through every layer, passing the same `attention_mask` to each.
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        self.layers
            .iter()
            .try_fold(xs.clone(), |hidden, layer| layer.forward(&hidden, attention_mask))
    }
}

#[derive(Debug, Clone)]
struct VisionEmbeddings {
    patch_embedding: candle_nn::Conv2d,
    position_embedding: candle_nn::Embedding,
    position_ids: Tensor,
}

impl VisionEmbeddings {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let conv2d_cfg = candle_nn::Conv2dConfig {
            stride: cfg.patch_size,
            ..Default::default()
        };
        let patch_embedding = candle_nn::conv2d(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            conv2d_cfg,
            vb.pp("patch_embedding"),
        )?;
        let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
        let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
        let position_embedding =
            candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
        Ok(Self {
            patch_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl Module for VisionEmbeddings {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_batch, _channels, _height, _width) = xs.dims4()?;
        let embeddings = xs.apply(&self.patch_embedding)?;
        let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
        embeddings.broadcast_add(&position_embedding)
    }
}

/// SigLIP vision encoder.
///
/// Patch embeddings feed the transformer encoder, followed by a post layer norm and an
/// optional attention-pooling head.
#[derive(Debug, Clone)]
struct VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    /// Present only when built with `use_head = true`.
    head: Option<MultiheadAttentionPoolingHead>,
}

impl VisionTransformer {
    fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let post_layernorm =
            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
        let head = use_head
            .then(|| MultiheadAttentionPoolingHead::new(cfg, vb.pp("head")))
            .transpose()?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            head,
        })
    }
}

impl Module for VisionTransformer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let embedded = self.embeddings.forward(xs)?;
        let encoded = self.encoder.forward(&embedded, None)?;
        let normed = self.post_layernorm.forward(&encoded)?;
        if let Some(head) = &self.head {
            head.forward(&normed)
        } else {
            Ok(normed)
        }
    }
}

/// Public entry point for the SigLIP vision tower.
#[derive(Debug, Clone)]
pub struct VisionModel {
    vision_model: VisionTransformer,
}

impl VisionModel {
    /// Builds the vision tower.
    ///
    /// With `use_head`, the forward pass returns one pooled vector per image. Otherwise it
    /// returns the post-layer-norm hidden state of every patch.
    pub fn new(cfg: &VisionConfig, use_head: bool, vb: VarBuilder) -> Result<Self> {
        VisionTransformer::new(cfg, use_head, vb).map(|vision_model| Self { vision_model })
    }
}

impl Module for VisionModel {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.vision_model.forward(xs)
    }
}

/// Token embeddings plus learned absolute position embeddings for the text tower.
#[derive(Debug, Clone)]
struct TextEmbeddings {
    token_embedding: candle_nn::Embedding,
    position_embedding: candle_nn::Embedding,
    /// `0..max_position_embeddings` as a `(1, max_position_embeddings)` u32 tensor.
    position_ids: Tensor,
}

impl TextEmbeddings {
    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        let max_len = cfg.max_position_embeddings;
        let token_embedding =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("token_embedding"))?;
        let position_embedding =
            candle_nn::embedding(max_len, cfg.hidden_size, vb.pp("position_embedding"))?;
        let position_ids = Tensor::arange(0u32, max_len as u32, vb.device())?.unsqueeze(0)?;
        Ok(Self {
            token_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl Module for TextEmbeddings {
    /// Adds the first `seq_len` position embeddings to the token embeddings.
    ///
    /// `seq_len` is the last dimension of `input_ids`.
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_len = input_ids.dim(D::Minus1)?;
        let token_embeds = self.token_embedding.forward(input_ids)?;
        let positions = self.position_ids.narrow(1, 0, seq_len)?;
        let position_embeds = self.position_embedding.forward(&positions)?;
        token_embeds.broadcast_add(&position_embeds)
    }
}

#[derive(Debug, Clone)]
pub struct TextTransformer {
    embeddings: TextEmbeddings,
    encoder: Encoder,
    final_layer_norm: LayerNorm,
    pub head: Linear,
}

impl TextTransformer {
    fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let final_layer_norm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            vb.pp("final_layer_norm"),
        )?;
        let head = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            encoder,
            final_layer_norm,
            head,
        })
    }
}
impl Module for TextTransformer {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let (_bsz, seq_len) = input_ids.dims2()?;
        let input_ids = self.embeddings.forward(input_ids)?;
        let input_ids = self.encoder.forward(&input_ids, None)?;
        let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
        last_hidden_state
            .i((.., seq_len - 1, ..))?
            .contiguous()?
            .apply(&self.head)
    }
}

/// Public entry point for the SigLIP text tower.
#[derive(Debug, Clone)]
pub struct TextModel {
    pub text_model: TextTransformer,
}

impl TextModel {
    pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
        TextTransformer::new(cfg, vb).map(|text_model| Self { text_model })
    }
}

impl Module for TextModel {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.text_model.forward(xs)
    }
}

/// Full SigLIP model: both towers plus the learned logit scale and bias.
#[derive(Clone, Debug)]
pub struct Model {
    text_model: TextModel,
    vision_model: VisionModel,
    /// Additive offset applied to every logit (shape `[1]`).
    logit_bias: Tensor,
    /// Log of the multiplicative logit scale (shape `[1]`); exponentiated in `forward`.
    logit_scale: Tensor,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            text_model: TextModel::new(&cfg.text_config, vb.pp("text_model"))?,
            vision_model: VisionModel::new(&cfg.vision_config, true, vb.pp("vision_model"))?,
            logit_scale: vb.get(&[1], "logit_scale")?,
            logit_bias: vb.get(&[1], "logit_bias")?,
        })
    }

    /// Pooled text features, one row per input sequence.
    pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.text_model.forward(input_ids)
    }

    /// Pooled image features, one row per image.
    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
        self.vision_model.forward(pixel_values)
    }

    /// Returns `(logits_per_text, logits_per_image)`.
    ///
    /// Both feature sets are first normalized with `div_l2_norm`. The pairwise products are
    /// then scaled by `exp(logit_scale)` and shifted by `logit_bias`. `logits_per_image` is the
    /// transpose of `logits_per_text`.
    pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
        let image_raw = self.get_image_features(pixel_values)?;
        let text_raw = self.get_text_features(input_ids)?;
        let image_unit = div_l2_norm(&image_raw)?;
        let text_unit = div_l2_norm(&text_raw)?;
        let similarity = text_unit.matmul(&image_unit.t()?)?;
        let scale = self.logit_scale.exp()?;
        let logits_per_text = similarity
            .broadcast_mul(&scale)?
            .broadcast_add(&self.logit_bias)?;
        let logits_per_image = logits_per_text.t()?;
        Ok((logits_per_text, logits_per_image))
    }
}