author | Czxck001 <10724409+Czxck001@users.noreply.github.com> | 2024-10-13 13:08:40 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-13 22:08:40 +0200 |
commit | ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e (patch) | |
tree | 8f61fd8b9a4c86b08e50328d051e0acec3945fb3 /candle-transformers/src/models/stable_diffusion/clip.rs | |
parent | 0d96ec31e8be03f844ed0aed636d6217dee9c7bc (diff) | |
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example
Add get_qkv_linear to handle different dimensionality in linears
Add stable diffusion 3 example
Add use_quant_conv and use_post_quant_conv for vae in stable diffusion
adapt existing AutoEncoderKLConfig to the change
add forward_until_encoder_layer to ClipTextTransformer
rename sd3 config to sd3_medium in mmdit; minor clean-up
Enable flash-attn for mmdit impl when the feature is enabled.
Add sd3 example codebase
add document
crediting references
pass the cargo fmt test
pass the clippy test
* Fix typos
* Expose cfg_scale and time_shift as options
* Replace the sample image with a JPG version; change the image output format accordingly
* Make error messages meaningful
* Remove the tail-end assignment in sd3_vae_vb_rename
* Remove the CUDA requirement
* Use default_value in clap args
* Add use_flash_attn to toggle flash-attn for MMDiT at runtime
* Resolve clippy errors and warnings
* Use default_value_t
* Pin the web-sys dependency
* Clippy fix
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/clip.rs')
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/clip.rs | 31 |
1 file changed, 31 insertions, 0 deletions
```diff
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 5254818e..2f631248 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -388,6 +388,37 @@ impl ClipTextTransformer {
         let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
         self.final_layer_norm.forward(&xs)
     }
+
+    pub fn forward_until_encoder_layer(
+        &self,
+        xs: &Tensor,
+        mask_after: usize,
+        until_layer: isize,
+    ) -> Result<(Tensor, Tensor)> {
+        let (bsz, seq_len) = xs.dims2()?;
+        let xs = self.embeddings.forward(xs)?;
+        let causal_attention_mask =
+            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
+
+        let mut xs = xs.clone();
+        let mut intermediate = xs.clone();
+
+        // Modified encoder.forward that returns the intermediate tensor along with final output.
+        let until_layer = if until_layer < 0 {
+            self.encoder.layers.len() as isize + until_layer
+        } else {
+            until_layer
+        } as usize;
+
+        for (layer_id, layer) in self.encoder.layers.iter().enumerate() {
+            xs = layer.forward(&xs, &causal_attention_mask)?;
+            if layer_id == until_layer {
+                intermediate = xs.clone();
+            }
+        }
+
+        Ok((self.final_layer_norm.forward(&xs)?, intermediate))
+    }
 }
 
 impl Module for ClipTextTransformer {
```
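The `until_layer` argument is an `isize` so callers can select an encoder layer Python-style, where a negative index counts back from the last layer (SD3 conventionally takes the CLIP hidden state from the penultimate layer). A minimal sketch of just that index arithmetic, extracted from the patch above (`resolve_layer_index` is a hypothetical helper name, not part of the actual change):

```rust
// Resolve a possibly negative layer index against the encoder's layer count,
// mirroring the normalization done inside forward_until_encoder_layer.
fn resolve_layer_index(until_layer: isize, num_layers: usize) -> usize {
    if until_layer < 0 {
        // e.g. -1 is the last layer, -2 the penultimate one.
        (num_layers as isize + until_layer) as usize
    } else {
        until_layer as usize
    }
}

fn main() {
    // A 12-layer text encoder: -2 selects layer 10 (the penultimate layer).
    println!("{}", resolve_layer_index(-2, 12)); // prints 10
    println!("{}", resolve_layer_index(3, 12)); // prints 3
}
```

Note that, as in the patched code, a non-negative `until_layer` greater than or equal to the layer count is not an error: the loop simply never matches it, and `intermediate` keeps its initial (embedding) value.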