diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-10-27 15:20:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-27 15:20:37 +0100 |
commit | 0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c (patch) | |
tree | daa4ae5e627a95c6c54cd8eabec32e15d71309be /candle-metal-kernels | |
parent | 594d984f9cf79207f3beb6114ddf73cbc8427b56 (diff) | |
download | candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.gz candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.bz2 candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.zip |
UG metal integration. (#2580)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/utils.rs | 10 |
2 files changed, 4 insertions, 8 deletions
```diff
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index be616009..222ae8ad 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -6,7 +6,7 @@ use std::collections::HashMap;
 use std::ffi::c_void;
 use std::sync::RwLock;
 
-mod utils;
+pub mod utils;
 pub use utils::BufferOffset;
 use utils::{get_block_dims, linear_split, EncoderProvider};
 
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index d2cc09f4..0092ecfa 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
 }
 
 // https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
-pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
+pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
     let mut pows0 = 0u64;
     let mut pows1 = 0u64;
     let mut pows2 = 0u64;
@@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
     }
 }
 
-pub(crate) fn set_param<P: EncoderParam>(
-    encoder: &ComputeCommandEncoderRef,
-    position: u64,
-    data: P,
-) {
+pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
     <P as EncoderParam>::set_param(encoder, position, data)
 }
 
 /// Helper functions to create the various objects on the compute command encoder
 /// on a single line.
 /// Prevents getting wrong some arguments number and mixing length and size in bytes.
-pub(crate) trait EncoderParam {
+pub trait EncoderParam {
     fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
 }
 macro_rules! primitive {
```