summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/stable-diffusion/clip.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-08-07 19:31:45 +0200
committer: GitHub <noreply@github.com> 2023-08-07 18:31:45 +0100
commit: fc265d9dcfc13ee0b03f4a09537a9e7156b29231 (patch)
tree: c94485d92647d340cb6e2a1eebc2c3aad4692d60 /candle-examples/examples/stable-diffusion/clip.rs
parent: 2345b8ce3f8ebab6e04d6ea25f7c809efb037995 (diff)
download: candle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.tar.gz
candle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.tar.bz2
candle-fc265d9dcfc13ee0b03f4a09537a9e7156b29231.zip
Some CLIP fixes for stable diffusion. (#338)
* Some CLIP fixes for stable diffusion.
* Add the avg-pool2d operation on cpu.
Diffstat (limited to 'candle-examples/examples/stable-diffusion/clip.rs')
-rw-r--r-- candle-examples/examples/stable-diffusion/clip.rs | 10
1 files changed, 5 insertions, 5 deletions
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 227660b1..ac9843f7 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
- token_embedding + position_embedding
+ token_embedding.broadcast_add(&position_embedding)
}
}
@@ -161,9 +161,9 @@ impl ClipAttention {
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
- let attn_weights =
- (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
- + causal_attention_mask)?;
+ let attn_weights = attn_weights
+ .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ .broadcast_add(causal_attention_mask)?;
let attn_weights =
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
@@ -287,7 +287,7 @@ impl ClipTextTransformer {
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..seq_len)
- .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
+ .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
mask.broadcast_as((bsz, seq_len, seq_len))