author     Laurent Mazare <laurent.mazare@gmail.com>    2024-04-02 14:36:28 +0200
committer  GitHub <noreply@github.com>                  2024-04-02 14:36:28 +0200
commit     b23436bf90b99eb17aed36aaa219875d3c962a7e (patch)
tree       8ee1aabc5bbf981ed66511fa7df7e2b26079b40f /candle-transformers
parent     be9c200cbb16b59fe1f1e8c0f606981412c9b757 (diff)
Stable diffusion fix. (#1993)
* Stable diffusion fix.
* And add a comment.
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/stable_diffusion/attention.rs | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index 07ce0fe4..05e51e44 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -533,7 +533,9 @@ impl Module for AttentionBlock {
         let attention_scores = (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
         let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
-        let xs = attention_probs.matmul(&value_states.contiguous()?)?;
+        // TODO: revert the call to force_contiguous once the three matmul kernels have been
+        // adapted to handle layout with some dims set to 1.
+        let xs = attention_probs.matmul(&value_states.force_contiguous()?)?;
         let xs = xs.to_dtype(in_dtype)?;
         let xs = xs.transpose(1, 2)?.contiguous()?;
         let xs = xs.flatten_from(D::Minus2)?;
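For context, a minimal sketch (not part of the patch) of why swapping `contiguous` for `force_contiguous` matters here. It uses candle-core's public `Tensor` API (`arange`, `reshape`, `transpose`, `is_contiguous`, `contiguous`, `force_contiguous`); the shapes are made up for illustration. candle's contiguity check exempts dims of size 1, so a tensor can report as contiguous while its raw strides differ from row-major order, which is the layout the TODO says the matmul kernels cannot yet handle.

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;

    // Shape (3, 1, 4), row-major strides (4, 4, 1).
    let a = Tensor::arange(0f32, 12., &dev)?.reshape((3, 1, 4))?;

    // Swap the two leading dims: shape (1, 3, 4), strides still (4, 4, 1).
    // Row-major strides for (1, 3, 4) would be (12, 4, 1), but the layout
    // check skips size-1 dims, so the tensor still reports as contiguous.
    let b = a.transpose(0, 1)?;
    assert!(b.is_contiguous());

    // contiguous() is a no-op for a tensor that already reports contiguous:
    // the stride-preserving view is returned as-is, and a kernel reading the
    // raw strides still sees the non-row-major layout.
    let c = b.contiguous()?;

    // force_contiguous() always copies into a fresh row-major buffer, which
    // is what the patch relies on for value_states before the matmul.
    let d = b.force_contiguous()?;

    let _ = (c, d);
    Ok(())
}
```

The trade-off is that `force_contiguous` copies even when no copy would strictly be needed, which is why the TODO marks it as a temporary workaround to revert once the kernels handle such layouts.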