author    Laurent Mazare <laurent.mazare@gmail.com>  2023-10-05 22:42:20 +0100
committer GitHub <noreply@github.com>  2023-10-05 22:42:20 +0100
commit    4631c48273330753c25cf75616ade152915301a2 (patch)
tree      ea0eb33e48a6f83fb4fae2df60f3bfefebafa747 /candle-transformers
parent    716883e9b0c3cb6a481bd3b38507f1cdeca3c642 (diff)
Remove some todos. (#1042)
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/stable_diffusion/attention.rs | 8
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index b3ea91f9..07ce0fe4 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -527,10 +527,10 @@ impl Module for AttentionBlock {
             .transpose_for_scores(value_proj)?
             .to_dtype(DType::F32)?;
-        let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
-        let attention_scores =
-            // TODO: Check that this needs two multiplication by `scale`.
-            (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
+        // scale is applied twice, hence the -0.25 here rather than -0.5.
+        // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L87
+        let scale = f64::powf(self.channels as f64 / self.num_heads as f64, -0.25);
+        let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
         let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
         let xs = attention_probs.matmul(&value_states.contiguous()?)?;
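The rationale in the new comment can be checked with plain scalar arithmetic: since (q * s) matmul (k^T * s) = s^2 * (q matmul k^T), applying scale = (channels / num_heads)^(-0.25) to both the query and the key scales the attention scores by (channels / num_heads)^(-0.5), i.e. the usual 1 / sqrt(head_dim) factor. Below is a minimal standalone Rust sketch of that equivalence; the channels and num_heads values are made up for illustration and are not taken from the model config.

fn main() {
    // Hypothetical values; any positive channels / num_heads pair works.
    let channels = 512.0_f64;
    let num_heads = 8.0_f64;
    let head_dim = channels / num_heads;

    // Scale as computed in the patched code, applied to both query and key.
    let scale = f64::powf(head_dim, -0.25);

    // (q * s) matmul (k^T * s) scales the scores by s^2 = head_dim^(-0.5).
    let effective = scale * scale;
    let reference = 1.0 / head_dim.sqrt();

    assert!((effective - reference).abs() < 1e-12);
    println!("scale = {scale}, scale^2 = {effective}, 1/sqrt(head_dim) = {reference}");
}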