summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/metal_backend/mod.rs11
-rw-r--r--candle-metal-kernels/src/indexing.metal4
-rw-r--r--candle-transformers/src/models/chinese_clip/mod.rs3
3 files changed, 16 insertions, 2 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 34931c9d..de107a61 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1237,7 +1237,7 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_l.shape().elem_count();
let dtype = self.dtype;
let device = self.device();
- let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
+ let buffer = device.new_buffer(dst_el, dtype, "gather")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
@@ -1324,14 +1324,23 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
+ (DType::U8, DType::U8) => "is_u8_u8",
+ (DType::U8, DType::U32) => "is_u8_u32",
+ (DType::U8, DType::I64) => "is_u8_i64",
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",
+ (DType::U32, DType::U8) => "is_u32_u8",
+ (DType::U32, DType::U32) => "is_u32_u32",
+ (DType::U32, DType::I64) => "is_u32_i64",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
+ (DType::I64, DType::U8) => "is_i64_u8",
+ (DType::I64, DType::U32) => "is_i64_u32",
+ (DType::I64, DType::I64) => "is_i64_i64",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 9eee97ca..c14f2c1f 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -193,12 +193,16 @@ INDEX_OP(is_i64_f16, int64_t, half)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif
+INDEX_OP(is_u32_u8, uint32_t, uint8_t)
+INDEX_OP(is_u32_u32, uint32_t, uint32_t)
INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
#endif
+INDEX_OP(is_u8_u8, uint8_t, uint8_t)
+INDEX_OP(is_u8_u32, uint8_t, uint32_t)
INDEX_OP(is_u8_f32, uint8_t, float)
INDEX_OP(is_u8_f16, uint8_t, half)
#if defined(__HAVE_BFLOAT__)
diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs
index 88472f0b..0f6eedd0 100644
--- a/candle-transformers/src/models/chinese_clip/mod.rs
+++ b/candle-transformers/src/models/chinese_clip/mod.rs
@@ -171,7 +171,8 @@ impl ChineseClipModel {
) -> Result<Tensor> {
let output = self
.text_model
- .forward(input_ids, token_type_ids, attention_mask)?;
+ .forward(input_ids, token_type_ids, attention_mask)?
+ .contiguous()?;
self.text_projection.forward(&output)
}