diff options
author | Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> | 2024-11-29 03:30:21 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-28 23:00:21 +0100 |
commit | 54e7fc3c97a6d40e459cee4d4bf2eff5c82390da (patch) | |
tree | 8ef4da0e255884de729b8c70fdf642b7c807d631 /candle-metal-kernels | |
parent | 23ed8a9ded155df7b5961d6a5ae12b4e8096a9c2 (diff) | |
download | candle-54e7fc3c97a6d40e459cee4d4bf2eff5c82390da.tar.gz candle-54e7fc3c97a6d40e459cee4d4bf2eff5c82390da.tar.bz2 candle-54e7fc3c97a6d40e459cee4d4bf2eff5c82390da.zip |
Lint fixes introduced with Rust 1.83 (#2646)
* Fixes for lint errors introduced with Rust 1.83
* rustfmt
* Fix more lints.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 20 | ||||
-rw-r--r-- | candle-metal-kernels/src/utils.rs | 17 |
2 files changed, 20 insertions, 17 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 0843cc11..5f948cbf 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -372,7 +372,7 @@ pub fn call_unary_contiguous_tiled( let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let tile_size = 2; - let tiles = (length + tile_size - 1) / tile_size; + let tiles = length.div_ceil(tile_size); encoder.set_compute_pipeline_state(&pipeline); @@ -594,7 +594,7 @@ pub fn call_reduce_contiguous( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - (elements_to_sum as u64 + 2 - 1) / 2, + (elements_to_sum as u64).div_ceil(2), ) .next_power_of_two(); @@ -1735,7 +1735,7 @@ pub fn call_sdpa_full( } }; - let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1759,16 +1759,16 @@ pub fn call_sdpa_full( let ldo = dk; let tn = 1; - let tm = (m + BM - 1) / BM; + let tm = m.div_ceil(BM); let b_stride_q = dk * qseq; let b_stride_k = dk * qseq; let b_stride_v = dk * qseq; let b_stride_o = dk * qseq; let swizzle_log = 0; - let gemm_n_iterations_aligned = (n + BN - 1) / BN; - let gemm_k_iterations_aligned = (k + bk - 1) / bk; - let gemm_sv_m_block_iterations = (m + BM - 1) / BM; + let gemm_n_iterations_aligned = n.div_ceil(BN); + let gemm_k_iterations_aligned = k.div_ceil(*bk); + let gemm_sv_m_block_iterations = m.div_ceil(BM); let batch_ndim = batch_shape.len(); let alpha = if softcapping != 1. { @@ -1906,7 +1906,7 @@ pub fn call_sdpa_vector( alpha }; - let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1933,7 +1933,7 @@ pub fn call_sdpa_vector( let grid_dims = MTLSize { width: 1, height: b as u64, - depth: 1 as u64, + depth: 1_u64, }; let group_dims = MTLSize { width: 1024, @@ -2320,7 +2320,7 @@ pub fn call_quantized_matmul_mv_t( } fn divide(m: usize, b: usize) -> NSUInteger { - ((m + b - 1) / b) as NSUInteger + m.div_ceil(b) as NSUInteger } #[allow(clippy::too_many_arguments)] diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 0092ecfa..025808d7 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -8,7 +8,7 @@ use std::ffi::c_void; pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { let size = length as u64; let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); - let count = (size + width - 1) / width; + let count = size.div_ceil(width); let thread_group_count = MTLSize { width: count, height: 1, @@ -128,7 +128,7 @@ impl EncoderParam for (&Buffer, usize) { } } -impl<'a> EncoderParam for &BufferOffset<'a> { +impl EncoderParam for &BufferOffset<'_> { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64); } @@ -169,7 +169,7 @@ pub struct WrappedEncoder<'a> { end_encoding_on_drop: bool, } -impl<'a> Drop for WrappedEncoder<'a> { +impl Drop for WrappedEncoder<'_> { fn drop(&mut self) { if self.end_encoding_on_drop { self.inner.end_encoding() @@ -177,14 +177,15 @@ impl<'a> Drop for WrappedEncoder<'a> { } } -impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> { +impl AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'_> { fn as_ref(&self) -> &metal::ComputeCommandEncoderRef { self.inner } } impl EncoderProvider for &metal::CommandBuffer { - type Encoder<'a> = WrappedEncoder<'a> + type Encoder<'a> + = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { @@ -196,7 +197,8 @@ impl EncoderProvider for &metal::CommandBuffer { } impl EncoderProvider for &metal::CommandBufferRef { - type Encoder<'a> = WrappedEncoder<'a> + type Encoder<'a> + = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { @@ -208,7 +210,8 @@ impl EncoderProvider for &metal::CommandBufferRef { } impl EncoderProvider for &ComputeCommandEncoderRef { - type Encoder<'a> = WrappedEncoder<'a> + type Encoder<'a> + = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { |