summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r--candle-metal-kernels/src/lib.rs20
1 files changed, 10 insertions, 10 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 0843cc11..5f948cbf 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -372,7 +372,7 @@ pub fn call_unary_contiguous_tiled(
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2;
- let tiles = (length + tile_size - 1) / tile_size;
+ let tiles = length.div_ceil(tile_size);
encoder.set_compute_pipeline_state(&pipeline);
@@ -594,7 +594,7 @@ pub fn call_reduce_contiguous(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
- (elements_to_sum as u64 + 2 - 1) / 2,
+ (elements_to_sum as u64).div_ceil(2),
)
.next_power_of_two();
@@ -1735,7 +1735,7 @@ pub fn call_sdpa_full(
}
};
- let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
+ let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1759,16 +1759,16 @@ pub fn call_sdpa_full(
let ldo = dk;
let tn = 1;
- let tm = (m + BM - 1) / BM;
+ let tm = m.div_ceil(BM);
let b_stride_q = dk * qseq;
let b_stride_k = dk * qseq;
let b_stride_v = dk * qseq;
let b_stride_o = dk * qseq;
let swizzle_log = 0;
- let gemm_n_iterations_aligned = (n + BN - 1) / BN;
- let gemm_k_iterations_aligned = (k + bk - 1) / bk;
- let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
+ let gemm_n_iterations_aligned = n.div_ceil(BN);
+ let gemm_k_iterations_aligned = k.div_ceil(*bk);
+ let gemm_sv_m_block_iterations = m.div_ceil(BM);
let batch_ndim = batch_shape.len();
let alpha = if softcapping != 1. {
@@ -1906,7 +1906,7 @@ pub fn call_sdpa_vector(
alpha
};
- let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
+ let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1933,7 +1933,7 @@ pub fn call_sdpa_vector(
let grid_dims = MTLSize {
width: 1,
height: b as u64,
- depth: 1 as u64,
+ depth: 1_u64,
};
let group_dims = MTLSize {
width: 1024,
@@ -2320,7 +2320,7 @@ pub fn call_quantized_matmul_mv_t(
}
fn divide(m: usize, b: usize) -> NSUInteger {
- ((m + b - 1) / b) as NSUInteger
+ m.div_ceil(b) as NSUInteger
}
#[allow(clippy::too_many_arguments)]