summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/pixtral/llava.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/pixtral/llava.rs')
-rw-r--r--candle-transformers/src/models/pixtral/llava.rs72
1 files changed, 72 insertions, 0 deletions
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
new file mode 100644
index 00000000..33e0aca0
--- /dev/null
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -0,0 +1,72 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+use super::vision_model;
+use crate::models::mistral;
+
+#[derive(serde::Deserialize, Debug, Clone)]
+pub struct Config {
+ pub projector_hidden_act: candle_nn::Activation,
+ pub text_config: mistral::Config,
+ pub vision_config: vision_model::Config,
+ pub image_token_index: usize,
+ pub image_seq_length: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct MultiModalProjector {
+ linear_1: Linear,
+ act: candle_nn::Activation,
+ linear_2: Linear,
+}
+
+impl MultiModalProjector {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size);
+ let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
+ let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
+ Ok(Self {
+ linear_1,
+ act: cfg.projector_hidden_act,
+ linear_2,
+ })
+ }
+}
+
+impl Module for MultiModalProjector {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.linear_1)?
+ .apply(&self.act)?
+ .apply(&self.linear_2)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ pub multi_modal_projector: MultiModalProjector,
+ pub language_model: mistral::Model,
+ pub vision_tower: vision_model::Model,
+ pub patch_size: usize,
+ pub dtype: candle::DType,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
+ let vision_tower = vision_model::Model::new(
+ &cfg.vision_config,
+ vb.pp("vision_tower").to_dtype(candle::DType::F32),
+ )?;
+ let multi_modal_projector = MultiModalProjector::new(
+ cfg,
+ vb.pp("multi_modal_projector").to_dtype(candle::DType::F32),
+ )?;
+ Ok(Self {
+ multi_modal_projector,
+ language_model,
+ vision_tower,
+ patch_size: cfg.vision_config.patch_size,
+ dtype: vb.dtype(),
+ })
+ }
+}