summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/stable_diffusion/clip.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/clip.rs')
-rw-r--r--candle-transformers/src/models/stable_diffusion/clip.rs389
1 files changed, 389 insertions, 0 deletions
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
new file mode 100644
index 00000000..e7a20270
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -0,0 +1,389 @@
+//! Contrastive Language-Image Pre-Training
+//!
+//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+//! pairs of images with related texts.
+//!
+//! https://github.com/openai/CLIP
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn as nn;
+use candle_nn::Module;
+
+#[derive(Debug, Clone, Copy)]
+pub enum Activation {
+ QuickGelu,
+ Gelu,
+ GeluErf,
+}
+
+impl Module for Activation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
+ Activation::Gelu => xs.gelu(),
+ Activation::GeluErf => xs.gelu_erf(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ vocab_size: usize,
+ embed_dim: usize, // aka config.hidden_size
+ activation: Activation, // aka config.hidden_act
+ intermediate_size: usize,
+ pub max_position_embeddings: usize,
+ // The character to use for padding, use EOS when not set.
+ pub pad_with: Option<String>,
+ num_hidden_layers: usize,
+ num_attention_heads: usize,
+ #[allow(dead_code)]
+ projection_dim: usize,
+}
+
+impl Config {
+ // The config details can be found in the "text_config" section of this json file:
+ // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
+ pub fn v1_5() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
+ pub fn v2_1() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1024,
+ intermediate_size: 4096,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 23,
+ num_attention_heads: 16,
+ projection_dim: 512,
+ activation: Activation::Gelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
+ pub fn sdxl() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
+ pub fn sdxl2() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1280,
+ intermediate_size: 5120,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 32,
+ num_attention_heads: 20,
+ projection_dim: 1280,
+ activation: Activation::Gelu,
+ }
+ }
+
+ // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
+ pub fn wuerstchen() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1024,
+ intermediate_size: 4096,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 24,
+ num_attention_heads: 16,
+ projection_dim: 1024,
+ activation: Activation::GeluErf,
+ }
+ }
+
+ // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json
+ pub fn wuerstchen_prior() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1280,
+ intermediate_size: 5120,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 32,
+ num_attention_heads: 20,
+ projection_dim: 512,
+ activation: Activation::GeluErf,
+ }
+ }
+}
+
+// CLIP Text Model
+// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
+#[derive(Debug)]
+struct ClipTextEmbeddings {
+ token_embedding: candle_nn::Embedding,
+ position_embedding: candle_nn::Embedding,
+ position_ids: Tensor,
+}
+
+impl ClipTextEmbeddings {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let token_embedding =
+ candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
+ let position_embedding = candle_nn::embedding(
+ c.max_position_embeddings,
+ c.embed_dim,
+ vs.pp("position_embedding"),
+ )?;
+ let position_ids =
+ Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
+ Ok(ClipTextEmbeddings {
+ token_embedding,
+ position_embedding,
+ position_ids,
+ })
+ }
+}
+
+impl Module for ClipTextEmbeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let token_embedding = self.token_embedding.forward(xs)?;
+ let position_embedding = self.position_embedding.forward(&self.position_ids)?;
+ token_embedding.broadcast_add(&position_embedding)
+ }
+}
+
+#[derive(Debug)]
+struct ClipAttention {
+ k_proj: candle_nn::Linear,
+ v_proj: candle_nn::Linear,
+ q_proj: candle_nn::Linear,
+ out_proj: candle_nn::Linear,
+ head_dim: usize,
+ scale: f64,
+ num_attention_heads: usize,
+}
+
+impl ClipAttention {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let embed_dim = c.embed_dim;
+ let num_attention_heads = c.num_attention_heads;
+ let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
+ let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
+ let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
+ let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
+ let head_dim = embed_dim / num_attention_heads;
+ let scale = (head_dim as f64).powf(-0.5);
+ Ok(ClipAttention {
+ k_proj,
+ v_proj,
+ q_proj,
+ out_proj,
+ head_dim,
+ scale,
+ num_attention_heads,
+ })
+ }
+
+ fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
+ xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let in_dtype = xs.dtype();
+ let (bsz, seq_len, embed_dim) = xs.dims3()?;
+ let query_states = (self.q_proj.forward(xs)? * self.scale)?;
+ let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
+ let query_states = self
+ .shape(&query_states, seq_len, bsz)?
+ .reshape(proj_shape)?
+ .to_dtype(DType::F32)?;
+ let key_states = self
+ .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?
+ .to_dtype(DType::F32)?;
+ let value_states = self
+ .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?
+ .to_dtype(DType::F32)?;
+ let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
+
+ let src_len = key_states.dim(1)?;
+ let attn_weights = attn_weights
+ .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ .broadcast_add(causal_attention_mask)?;
+ let attn_weights =
+ attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+
+ let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
+ let attn_output = attn_output
+ .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
+ .transpose(1, 2)?
+ .reshape((bsz, seq_len, embed_dim))?;
+ self.out_proj.forward(&attn_output)
+ }
+}
+
+#[derive(Debug)]
+struct ClipMlp {
+ fc1: candle_nn::Linear,
+ fc2: candle_nn::Linear,
+ activation: Activation,
+}
+
+impl ClipMlp {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
+ let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
+ Ok(ClipMlp {
+ fc1,
+ fc2,
+ activation: c.activation,
+ })
+ }
+}
+
+impl ClipMlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.fc1.forward(xs)?;
+ self.fc2.forward(&self.activation.forward(&xs)?)
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoderLayer {
+ self_attn: ClipAttention,
+ layer_norm1: candle_nn::LayerNorm,
+ mlp: ClipMlp,
+ layer_norm2: candle_nn::LayerNorm,
+}
+
+impl ClipEncoderLayer {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
+ let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
+ let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
+ let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
+ Ok(ClipEncoderLayer {
+ self_attn,
+ layer_norm1,
+ mlp,
+ layer_norm2,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.layer_norm1.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
+ let xs = (xs + residual)?;
+
+ let residual = &xs;
+ let xs = self.layer_norm2.forward(&xs)?;
+ let xs = self.mlp.forward(&xs)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoder {
+ layers: Vec<ClipEncoderLayer>,
+}
+
+impl ClipEncoder {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("layers");
+ let mut layers: Vec<ClipEncoderLayer> = Vec::new();
+ for index in 0..c.num_hidden_layers {
+ let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ layers.push(layer)
+ }
+ Ok(ClipEncoder { layers })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, causal_attention_mask)?;
+ }
+ Ok(xs)
+ }
+}
+
+/// A CLIP transformer based model.
+#[derive(Debug)]
+pub struct ClipTextTransformer {
+ embeddings: ClipTextEmbeddings,
+ encoder: ClipEncoder,
+ final_layer_norm: candle_nn::LayerNorm,
+}
+
+impl ClipTextTransformer {
+ pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("text_model");
+ let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
+ let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
+ let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
+ Ok(ClipTextTransformer {
+ embeddings,
+ encoder,
+ final_layer_norm,
+ })
+ }
+
+ // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
+ fn build_causal_attention_mask(
+ bsz: usize,
+ seq_len: usize,
+ mask_after: usize,
+ device: &Device,
+ ) -> Result<Tensor> {
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| {
+ (0..seq_len).map(move |j| {
+ if j > i || j > mask_after {
+ f32::MIN
+ } else {
+ 0.
+ }
+ })
+ })
+ .collect();
+ let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
+ mask.broadcast_as((bsz, seq_len, seq_len))
+ }
+
+ pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
+ let (bsz, seq_len) = xs.dims2()?;
+ let xs = self.embeddings.forward(xs)?;
+ let causal_attention_mask =
+ Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
+ let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
+ self.final_layer_norm.forward(&xs)
+ }
+}
+
+impl Module for ClipTextTransformer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.forward_with_mask(xs, usize::MAX)
+ }
+}