summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-11-12 17:10:12 +0100
committerGitHub <noreply@github.com>2024-11-12 17:10:12 +0100
commit06350c31c780d6ea485f506032aea6ff8809e38a (patch)
tree811d3bd8443335b93f17d547d3abb5a9b3fc6ffd /candle-transformers
parent9453cc30958dd0e9209aaeba30b15bb97aff0ea9 (diff)
downloadcandle-06350c31c780d6ea485f506032aea6ff8809e38a.tar.gz
candle-06350c31c780d6ea485f506032aea6ff8809e38a.tar.bz2
candle-06350c31c780d6ea485f506032aea6ff8809e38a.zip
Add some missing index-select metal kernels. (#2613)
* Add some missing index-select metal kernels. * Make some matrix contiguous pre-matmul.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/chinese_clip/mod.rs3
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs
index 88472f0b..0f6eedd0 100644
--- a/candle-transformers/src/models/chinese_clip/mod.rs
+++ b/candle-transformers/src/models/chinese_clip/mod.rs
@@ -171,7 +171,8 @@ impl ChineseClipModel {
) -> Result<Tensor> {
let output = self
.text_model
- .forward(input_ids, token_type_ids, attention_mask)?;
+ .forward(input_ids, token_type_ids, attention_mask)?
+ .contiguous()?;
self.text_projection.forward(&output)
}