path: root/candle-transformers/src/models/stable_diffusion/clip.rs
author     Czxck001 <10724409+Czxck001@users.noreply.github.com>  2024-10-13 13:08:40 -0700
committer  GitHub <noreply@github.com>  2024-10-13 22:08:40 +0200
commit  ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e (patch)
tree    8f61fd8b9a4c86b08e50328d051e0acec3945fb3 /candle-transformers/src/models/stable_diffusion/clip.rs
parent  0d96ec31e8be03f844ed0aed636d6217dee9c7bc (diff)
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example
  - Add get_qkv_linear to handle different dimensionality in linears
  - Add use_quant_conv and use_post_quant_conv for the VAE in stable diffusion
  - Adapt the existing AutoEncoderKLConfig to the change
  - Add forward_until_encoder_layer to ClipTextTransformer
  - Rename the sd3 config to sd3_medium in mmdit; minor clean-up
  - Enable flash-attn for the mmdit impl when the feature is enabled
  - Add the sd3 example codebase
  - Add a document crediting references
  - Pass the cargo fmt and clippy checks
* Fix typos
* Expose cfg_scale and time_shift as options
* Replace the sample image with a JPG version; change the image output format accordingly
* Make the error messages meaningful
* Remove the tail-end assignment in sd3_vae_vb_rename
* Remove the CUDA requirement
* Use default_value in clap args
* Add use_flash_attn to turn flash-attn for MMDiT on/off at runtime
* Resolve clippy errors and warnings
* Use default_value_t
* Pin the web-sys dependency
* Clippy fix
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/clip.rs')
-rw-r--r--  candle-transformers/src/models/stable_diffusion/clip.rs  31
1 file changed, 31 insertions, 0 deletions
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 5254818e..2f631248 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -388,6 +388,37 @@ impl ClipTextTransformer {
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
self.final_layer_norm.forward(&xs)
}
+
+ pub fn forward_until_encoder_layer(
+ &self,
+ xs: &Tensor,
+ mask_after: usize,
+ until_layer: isize,
+ ) -> Result<(Tensor, Tensor)> {
+ let (bsz, seq_len) = xs.dims2()?;
+ let xs = self.embeddings.forward(xs)?;
+ let causal_attention_mask =
+ Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
+
+ let mut xs = xs.clone();
+ let mut intermediate = xs.clone();
+
+ // Modified encoder.forward that returns the intermediate tensor along with final output.
+ let until_layer = if until_layer < 0 {
+ self.encoder.layers.len() as isize + until_layer
+ } else {
+ until_layer
+ } as usize;
+
+ for (layer_id, layer) in self.encoder.layers.iter().enumerate() {
+ xs = layer.forward(&xs, &causal_attention_mask)?;
+ if layer_id == until_layer {
+ intermediate = xs.clone();
+ }
+ }
+
+ Ok((self.final_layer_norm.forward(&xs)?, intermediate))
+ }
}
impl Module for ClipTextTransformer {
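
For context, a minimal sketch of how the new method might be called from downstream code. This is not part of the commit: the crate paths (candle_core, candle_transformers) and the encode_prompt wrapper are assumptions for illustration; only forward_until_encoder_layer comes from this change. SD3-style pipelines typically condition on the penultimate CLIP hidden state, which is what until_layer = -2 selects.

use candle_core::{Device, Result, Tensor};
use candle_transformers::models::stable_diffusion::clip::ClipTextTransformer;

// Hypothetical helper: encode already-tokenized prompt ids and return both the
// final-layer-normed output and the penultimate encoder layer's hidden state.
fn encode_prompt(
    clip: &ClipTextTransformer,
    token_ids: &[u32],
    device: &Device,
) -> Result<(Tensor, Tensor)> {
    // Shape the ids into the (batch, seq_len) layout that dims2() expects.
    let xs = Tensor::new(token_ids, device)?.unsqueeze(0)?;
    // until_layer = -2 resolves to layers.len() - 2 via the negative-index
    // handling in the diff above, i.e. the penultimate encoder layer. The
    // mask_after argument is forwarded to build_causal_attention_mask; we
    // pass the full sequence length here (an assumption about its semantics).
    clip.forward_until_encoder_layer(&xs, token_ids.len(), -2)
}

Returning both tensors from a single pass avoids running the encoder twice when a pipeline needs the intermediate state for conditioning and the final state elsewhere.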