author    Laurent Mazare <laurent.mazare@gmail.com>   2023-10-05 22:42:20 +0100
committer GitHub <noreply@github.com>                 2023-10-05 22:42:20 +0100
commit    4631c48273330753c25cf75616ade152915301a2 (patch)
tree      ea0eb33e48a6f83fb4fae2df60f3bfefebafa747 /candle-transformers
parent    716883e9b0c3cb6a481bd3b38507f1cdeca3c642 (diff)
Remove some todos. (#1042)
Diffstat (limited to 'candle-transformers')
-rw-r--r--   candle-transformers/src/models/stable_diffusion/attention.rs | 8
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index b3ea91f9..07ce0fe4 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -527,10 +527,10 @@ impl Module for AttentionBlock {
             .transpose_for_scores(value_proj)?
             .to_dtype(DType::F32)?;
 
-        let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
-        let attention_scores =
-            // TODO: Check that this needs two multiplication by `scale`.
-            (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
+        // scale is applied twice, hence the -0.25 here rather than -0.5.
+        // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L87
+        let scale = f64::powf(self.channels as f64 / self.num_heads as f64, -0.25);
+        let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
         let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
 
         let xs = attention_probs.matmul(&value_states.contiguous()?)?;
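
The removed TODO asked whether `scale` really needs to be applied twice. It does: multiplying both the query and the key by `(channels / num_heads)^-0.25` is algebraically the same as multiplying the raw `q·kᵀ` scores once by `(channels / num_heads)^-0.5`, the usual `1/sqrt(head_dim)` attention factor; splitting it across the two operands just keeps the pre-softmax intermediates closer to unit magnitude. A minimal standalone sketch of that equivalence with candle-core (the shapes and the `head_dim` value below are made up for illustration and are not part of this commit):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical sizes: 1 (batch * heads), 4 tokens, head_dim = channels / num_heads = 8.
    let head_dim = 8usize;
    let q = Tensor::randn(0f32, 1f32, (1, 4, head_dim), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (1, 4, head_dim), &dev)?;

    // Split scaling, as in the patched code: each operand gets head_dim^-0.25.
    let scale = f64::powf(head_dim as f64, -0.25);
    let split = (q.clone() * scale)?.matmul(&(k.t()? * scale)?)?;

    // Single scaling of the scores by head_dim^-0.5.
    let single = (q.matmul(&k.t()?)? * f64::powf(head_dim as f64, -0.5))?;

    // The two differ only by floating-point rounding.
    let total_abs_diff = (split - single)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    println!("total abs difference: {total_abs_diff}");
    Ok(())
}
```

The split form matches the diffusers implementation linked in the new comment, which is presumably why the patch keeps the exponent at -0.25 rather than folding the whole factor into one multiplication.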