author     Laurent Mazare <laurent.mazare@gmail.com>  2024-04-02 14:36:28 +0200
committer  GitHub <noreply@github.com>                2024-04-02 14:36:28 +0200
commit     b23436bf90b99eb17aed36aaa219875d3c962a7e (patch)
tree       8ee1aabc5bbf981ed66511fa7df7e2b26079b40f /candle-transformers
parent     be9c200cbb16b59fe1f1e8c0f606981412c9b757 (diff)
Stable diffusion fix. (#1993)
* Stable diffusion fix.
* And add a comment.
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/stable_diffusion/attention.rs  4
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index 07ce0fe4..05e51e44 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -533,7 +533,9 @@ impl Module for AttentionBlock {
let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
- let xs = attention_probs.matmul(&value_states.contiguous()?)?;
+ // TODO: revert the call to force_contiguous once the three matmul kernels have been
+ // adapted to handle layout with some dims set to 1.
+ let xs = attention_probs.matmul(&value_states.force_contiguous()?)?;
let xs = xs.to_dtype(in_dtype)?;
let xs = xs.transpose(1, 2)?.contiguous()?;
let xs = xs.flatten_from(D::Minus2)?;
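
For context, below is a minimal, self-contained sketch (not part of the patch) of the attention pattern this fix touches, assuming the candle-core and candle-nn APIs at this revision; the tensor shapes and the scale value are illustrative only. As the new TODO comment notes, contiguous() can be a no-op on layouts with some dims set to 1, while force_contiguous() always materializes a fresh copy in the standard layout the matmul kernels expect.

// Minimal sketch, assuming candle-core / candle-nn as of this commit.
// Shapes and the scale factor are made up for illustration.
use candle_core::{D, Device, Result, Tensor};
use candle_nn as nn;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical attention shapes: (batch = 1, heads = 4, seq = 8, head_dim = 16).
    let query_states = Tensor::randn(0f32, 1f32, (1, 4, 8, 16), &dev)?;
    let key_states = Tensor::randn(0f32, 1f32, (1, 4, 8, 16), &dev)?;
    let value_states = Tensor::randn(0f32, 1f32, (1, 4, 8, 16), &dev)?;
    let scale = 0.25; // the real block derives this from the head dimension

    let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
    let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
    // force_contiguous copies even when the layout already reports itself as
    // contiguous, sidestepping the matmul kernel limitation mentioned in the TODO.
    let xs = attention_probs.matmul(&value_states.force_contiguous()?)?;
    println!("{:?}", xs.dims());
    Ok(())
}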