diff options
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 11 | ||||
-rw-r--r-- | candle-metal-kernels/src/indexing.metal | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/chinese_clip/mod.rs | 3 |
3 files changed, 16 insertions, 2 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 34931c9d..de107a61 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1237,7 +1237,7 @@ impl BackendStorage for MetalStorage { let dst_el = ids_l.shape().elem_count(); let dtype = self.dtype; let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype, "index_select")?; + let buffer = device.new_buffer(dst_el, dtype, "gather")?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "gather_u32_f32", (DType::U32, DType::F16) => "gather_u32_f16", @@ -1324,14 +1324,23 @@ impl BackendStorage for MetalStorage { let device = self.device(); let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::U8) => "is_u8_u8", + (DType::U8, DType::U32) => "is_u8_u32", + (DType::U8, DType::I64) => "is_u8_i64", (DType::U8, DType::BF16) => "is_u8_bf16", (DType::U8, DType::F32) => "is_u8_f32", (DType::U8, DType::F16) => "is_u8_f16", + (DType::U32, DType::U8) => "is_u32_u8", + (DType::U32, DType::U32) => "is_u32_u32", + (DType::U32, DType::I64) => "is_u32_i64", (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I64, DType::U8) => "is_i64_u8", + (DType::I64, DType::U32) => "is_i64_u32", + (DType::I64, DType::I64) => "is_i64_i64", (DType::I64, DType::F32) => "is_i64_f32", (DType::I64, DType::F16) => "is_i64_f16", (DType::I64, DType::BF16) => "is_i64_bf16", diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 9eee97ca..c14f2c1f 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -193,12 +193,16 @@ INDEX_OP(is_i64_f16, int64_t, half) INDEX_OP(is_i64_bf16, int64_t, bfloat) #endif +INDEX_OP(is_u32_u8, uint32_t, uint8_t) +INDEX_OP(is_u32_u32, uint32_t, uint32_t) INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f16, uint32_t, half) #if defined(__HAVE_BFLOAT__) INDEX_OP(is_u32_bf16, uint32_t, bfloat) #endif +INDEX_OP(is_u8_u8, uint8_t, uint8_t) +INDEX_OP(is_u8_u32, uint8_t, uint32_t) INDEX_OP(is_u8_f32, uint8_t, float) INDEX_OP(is_u8_f16, uint8_t, half) #if defined(__HAVE_BFLOAT__) diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs index 88472f0b..0f6eedd0 100644 --- a/candle-transformers/src/models/chinese_clip/mod.rs +++ b/candle-transformers/src/models/chinese_clip/mod.rs @@ -171,7 +171,8 @@ impl ChineseClipModel { ) -> Result<Tensor> { let output = self .text_model - .forward(input_ids, token_type_ids, attention_mask)?; + .forward(input_ids, token_type_ids, attention_mask)? + .contiguous()?; self.text_projection.forward(&output) } |