Diffstat (limited to 'candle-examples/examples/stable-diffusion/clip.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/clip.rs | 10
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 227660b1..ac9843f7 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let token_embedding = self.token_embedding.forward(xs)?;
         let position_embedding = self.position_embedding.forward(&self.position_ids)?;
-        token_embedding + position_embedding
+        token_embedding.broadcast_add(&position_embedding)
     }
 }
 
@@ -161,9 +161,9 @@ impl ClipAttention {
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
 
         let src_len = key_states.dim(1)?;
-        let attn_weights =
-            (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
-                + causal_attention_mask)?;
+        let attn_weights = attn_weights
+            .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+            .broadcast_add(causal_attention_mask)?;
         let attn_weights =
             attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
@@ -287,7 +287,7 @@ impl ClipTextTransformer {
     // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
     fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
         let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
+            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
             .collect();
         let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         mask.broadcast_as((bsz, seq_len, seq_len))
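
For context on what the change buys: candle's plain `+` on tensors requires both operands to have identical shapes, while the `broadcast_*` methods expand the smaller operand following NumPy-style broadcasting rules; and an additive f32 mask (0.0 for visible positions, f32::MIN for future ones) pushes masked logits toward negative infinity before the softmax, whereas adding the old 0/1 u8 mask would merely raise them by one. The standalone sketch below reproduces the post-diff pattern in isolation; the crate names (candle_core, candle_nn) and the toy batch/head/sequence sizes are assumptions made for illustration, not taken from this commit.

// Minimal sketch of the additive causal mask plus broadcast_add pattern.
// Assumes the candle_core and candle_nn crates; shapes are toy values.
use candle_core::{DType, Device, Result, Tensor, D};

fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
    // 0.0 on and below the diagonal, f32::MIN above it: adding this to the
    // logits makes future positions vanish after the softmax.
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
        .collect();
    let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
    mask.broadcast_as((bsz, seq_len, seq_len))
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (bsz, num_heads, seq_len) = (1, 4, 3);
    // Stand-in attention logits, flattened over heads as in the CLIP code.
    let attn_weights = Tensor::zeros((bsz * num_heads, seq_len, seq_len), DType::F32, &device)?;
    let mask = build_causal_attention_mask(bsz, seq_len, &device)?;
    // `+` would demand identical shapes; `broadcast_add` expands the
    // (bsz, seq_len, seq_len) mask across the head dimension instead.
    let attn_weights = attn_weights
        .reshape((bsz, num_heads, seq_len, seq_len))?
        .broadcast_add(&mask)?;
    let attn_weights = attn_weights.reshape((bsz * num_heads, seq_len, seq_len))?;
    let probs = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
    println!("{probs}"); // each row sums to 1 and future positions get ~0 weight
    Ok(())
}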