diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-07 19:31:45 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-07 18:31:45 +0100 |
commit | fc265d9dcfc13ee0b03f4a09537a9e7156b29231 (patch) | |
tree | c94485d92647d340cb6e2a1eebc2c3aad4692d60 /candle-examples/examples/stable-diffusion/clip.rs | |
parent | 2345b8ce3f8ebab6e04d6ea25f7c809efb037995 (diff) | |
download | candle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.tar.gz candle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.tar.bz2 candle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.zip |
Some CLIP fixes for stable diffusion. (#338)
* Some CLIP fixes for stable diffusion.
* Add the avg-pool2d operation on cpu.
Diffstat (limited to 'candle-examples/examples/stable-diffusion/clip.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/clip.rs | 10 |
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs index 227660b1..ac9843f7 100644 --- a/candle-examples/examples/stable-diffusion/clip.rs +++ b/candle-examples/examples/stable-diffusion/clip.rs @@ -103,7 +103,7 @@ impl ClipTextEmbeddings { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let token_embedding = self.token_embedding.forward(xs)?; let position_embedding = self.position_embedding.forward(&self.position_ids)?; - token_embedding + position_embedding + token_embedding.broadcast_add(&position_embedding) } } @@ -161,9 +161,9 @@ impl ClipAttention { let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; let src_len = key_states.dim(1)?; - let attn_weights = - (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))? - + causal_attention_mask)?; + let attn_weights = attn_weights + .reshape((bsz, self.num_attention_heads, seq_len, src_len))? + .broadcast_add(causal_attention_mask)?; let attn_weights = attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?; let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; @@ -287,7 +287,7 @@ impl ClipTextTransformer { // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678 fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> { let mask: Vec<_> = (0..seq_len) - .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i))) + .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. })) .collect(); let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?; mask.broadcast_as((bsz, seq_len, seq_len)) |