diff options
Diffstat (limited to 'candle-transformers/src/models/pixtral/llava.rs')
-rw-r--r-- | candle-transformers/src/models/pixtral/llava.rs | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs new file mode 100644 index 00000000..33e0aca0 --- /dev/null +++ b/candle-transformers/src/models/pixtral/llava.rs @@ -0,0 +1,72 @@ +use candle::{Module, Result, Tensor}; +use candle_nn::{linear, Linear, VarBuilder}; + +use super::vision_model; +use crate::models::mistral; + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + pub projector_hidden_act: candle_nn::Activation, + pub text_config: mistral::Config, + pub vision_config: vision_model::Config, + pub image_token_index: usize, + pub image_seq_length: usize, +} + +#[derive(Debug, Clone)] +pub struct MultiModalProjector { + linear_1: Linear, + act: candle_nn::Activation, + linear_2: Linear, +} + +impl MultiModalProjector { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size); + let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?; + let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?; + Ok(Self { + linear_1, + act: cfg.projector_hidden_act, + linear_2, + }) + } +} + +impl Module for MultiModalProjector { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.apply(&self.linear_1)? + .apply(&self.act)? + .apply(&self.linear_2) + } +} + +#[derive(Debug, Clone)] +pub struct Model { + pub multi_modal_projector: MultiModalProjector, + pub language_model: mistral::Model, + pub vision_tower: vision_model::Model, + pub patch_size: usize, + pub dtype: candle::DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?; + let vision_tower = vision_model::Model::new( + &cfg.vision_config, + vb.pp("vision_tower").to_dtype(candle::DType::F32), + )?; + let multi_modal_projector = MultiModalProjector::new( + cfg, + vb.pp("multi_modal_projector").to_dtype(candle::DType::F32), + )?; + Ok(Self { + multi_modal_projector, + language_model, + vision_tower, + patch_size: cfg.vision_config.patch_size, + dtype: vb.dtype(), + }) + } +} |