summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs2
-rw-r--r--candle-metal-kernels/src/utils.rs10
2 files changed, 4 insertions, 8 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index be616009..222ae8ad 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
-mod utils;
+pub mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index d2cc09f4..0092ecfa 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
-pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
+pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
@@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
}
}
-pub(crate) fn set_param<P: EncoderParam>(
- encoder: &ComputeCommandEncoderRef,
- position: u64,
- data: P,
-) {
+pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
-pub(crate) trait EncoderParam {
+pub trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {