-rw-r--r-- | candle-core/src/metal_backend.rs  | 23
-rw-r--r-- | candle-metal-kernels/src/lib.rs   |  6
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 21
-rw-r--r-- | candle-nn/src/ops.rs              |  3
4 files changed, 33 insertions, 20 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index f570d2c5..424b29d9 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -482,11 +482,14 @@ impl BackendStorage for MetalStorage {
     }
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
-        if !(sum_dims.len() == 1
-            && sum_dims[0] == layout.shape().rank() - 1
-            && layout.stride()[sum_dims[0]] == 1)
-        {
-            crate::bail!("Non last dim reduce op not supported yet");
+        if sum_dims.len() != 1 {
+            crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet.");
+        }
+        if sum_dims[0] != layout.shape().rank() - 1 {
+            crate::bail!("Non last dim reduce op {op:?} not implemented yet");
+        }
+        if layout.stride()[sum_dims[0]] != 1 {
+            crate::bail!("Non contiguous reduce op {op:?} not implemented yet");
         }
 
         let device = self.device.clone();
@@ -524,7 +527,7 @@ impl BackendStorage for MetalStorage {
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
         if dtype == DType::U32 {
-            crate::bail!("Implement return index reduce op");
+            crate::bail!("reduce op {name} is not implemented yet.");
         }
         let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
@@ -790,12 +793,16 @@ impl BackendStorage for MetalStorage {
         let buffer = self.device.new_buffer(el, dtype, "where")?;
         let command_buffer = self.device.command_buffer()?;
         if t.dtype() != f.dtype() {
-            crate::bail!("Invalid ternary different dtypes for values");
+            crate::bail!(
+                "Invalid where: different dtypes for values {:?} != {:?}",
+                t.dtype(),
+                f.dtype()
+            );
         }
         let name = match (self.dtype, t.dtype()) {
             (DType::U8, DType::F32) => "where_u8_f32",
             (DType::U8, DType::F16) => "where_u8_f16",
-            (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
+            (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"),
         };
         candle_metal_kernels::call_where_cond_strided(
             &device.device,
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a23aa47c..f2db171e 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -597,6 +597,7 @@ pub fn call_last_softmax(
     length: usize,
     elements_to_sum: usize,
     input: &Buffer,
+    input_offset: usize,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -604,7 +605,10 @@ pub fn call_last_softmax(
     encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, elements_to_sum, input, output));
+    set_params!(
+        encoder,
+        (length, elements_to_sum, (input, input_offset), output)
+    );
 
     let out_length = length / elements_to_sum;
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 75c2f013..9c9475a2 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -312,7 +312,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
         &device,
         command_buffer,
         &kernels,
-        "affine_float",
+        "affine_f32",
         size,
         &input,
         &output,
@@ -346,7 +346,7 @@ fn run_affine_strided<T: Clone>(
         &device,
         command_buffer,
         &kernels,
-        "affine_float_strided",
+        "affine_f32_strided",
         shape,
         &input,
         strides,
@@ -608,6 +608,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
         v.len(),
         last_dim,
         &input,
+        0,
         &output,
     )
     .unwrap();
@@ -622,7 +623,7 @@ fn reduce_sum() {
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let out_length = 1;
 
-    let results = run_reduce(&v, out_length, "fast_sum_float");
+    let results = run_reduce(&v, out_length, "fast_sum_f32");
     assert_eq!(approx(results, 4), vec![21.0]);
 }
 
@@ -631,7 +632,7 @@ fn reduce_sum2() {
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let out_length = 2;
 
-    let results = run_reduce(&v, out_length, "fast_sum_float");
+    let results = run_reduce(&v, out_length, "fast_sum_f32");
     assert_eq!(approx(results, 4), vec![6.0, 15.0]);
 }
 
@@ -639,7 +640,7 @@ fn reduce_sum2() {
 fn softmax() {
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = run_softmax(&v, last_dim, "softmax_f32");
     assert_eq!(
         approx(results, 4),
         vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@@ -651,7 +652,7 @@ fn softmax() {
     for i in 0..n {
         v[i * last_dim] = 20.0;
     }
-    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = run_softmax(&v, last_dim, "softmax_f32");
     let results = approx(results, 4);
     println!("{results:?}");
     assert_eq!(
@@ -665,7 +666,7 @@ fn softmax() {
 
     let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
     let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = run_softmax(&v, last_dim, "softmax_f32");
     assert_eq!(
         approx(results, 4),
         vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@@ -673,7 +674,7 @@ fn softmax() {
 
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let last_dim = 3;
-    let results = run_softmax(&v, last_dim, "softmax_float");
+    let results = run_softmax(&v, last_dim, "softmax_f32");
     assert_eq!(
         approx(results, 4),
         vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
@@ -684,7 +685,7 @@ fn softmax() {
         .map(|v| f16::from_f32(*v))
         .collect::<Vec<_>>();
     let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_half");
+    let results = run_softmax(&v, last_dim, "softmax_f16");
     assert_eq!(
         approx_f16(results, 4),
         vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
@@ -695,7 +696,7 @@ fn softmax() {
         .map(|v| bf16::from_f32(*v))
         .collect::<Vec<_>>();
     let last_dim = 6;
-    let results = run_softmax(&v, last_dim, "softmax_bfloat");
+    let results = run_softmax(&v, last_dim, "softmax_bf16");
     assert_eq!(
         approx_bf16(results, 4),
         vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 94380f12..816eff42 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
 
         let n = layout.stride().len();
-        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
             candle::bail!("Non contiguous softmax-last-dim is not implemented");
         }
 
@@ -235,6 +235,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
             elem_count,
             last_dim,
             storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
             &mut output,
         )
         .unwrap();
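Note on the new input_offset argument (not part of the diff): a minimal sketch of how a caller might pass it, assuming the metal-rs types used by candle-metal-kernels at this commit. The wrapper name softmax_last_dim_f32 and the parameter start_offset_elems are hypothetical; the offset is in bytes, mirroring the layout.start_offset() * storage.dtype().size_in_bytes() expression added to candle-nn/src/ops.rs above.

use candle_metal_kernels::{call_last_softmax, Kernels, MetalKernelError};
use metal::{Buffer, CommandBufferRef, Device};

// Hypothetical wrapper: run the f32 last-dim softmax kernel on `input`,
// starting `start_offset_elems` elements into the buffer.
fn softmax_last_dim_f32(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    input: &Buffer,
    start_offset_elems: usize, // element offset, e.g. Layout::start_offset()
    elem_count: usize,
    last_dim: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    call_last_softmax(
        device,
        command_buffer,
        kernels,
        "softmax_f32", // kernel names now carry _f32/_f16/_bf16 suffixes
        elem_count,
        last_dim,
        input,
        // the kernel expects a byte offset, so scale the element offset by the
        // dtype size (4 bytes for f32)
        start_offset_elems * std::mem::size_of::<f32>(),
        output,
    )
}

With the literal 0 passed in the updated tests, behavior is unchanged for buffers that start at offset zero; the candle-nn change above is what lets softmax-last-dim run on contiguous views whose start offset is non-zero.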