diff options
Diffstat (limited to 'candle-core/src/accelerate.rs')
-rw-r--r-- | candle-core/src/accelerate.rs | 111 |
1 files changed, 111 insertions, 0 deletions
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs new file mode 100644 index 00000000..8b0df5c1 --- /dev/null +++ b/candle-core/src/accelerate.rs @@ -0,0 +1,111 @@ +#![allow(dead_code)] +use libc::{c_char, c_double, c_float, c_int}; + +mod ffi { + use super::*; + extern "C" { + // It would be nice to be able to switch to the NEWLAPACK version of the function but this + // seems to trigger some link error. Available function names can be seen here: + // /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd + #[link_name = "sgemm_"] + pub fn sgemm_ffi( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_float, + a: *const c_float, + lda: *const c_int, + b: *const c_float, + ldb: *const c_int, + beta: *const c_float, + c: *mut c_float, + ldc: *const c_int, + ); + #[link_name = "dgemm_"] + pub fn dgemm_ffi( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_double, + a: *const c_double, + lda: *const c_int, + b: *const c_double, + ldb: *const c_int, + beta: *const c_double, + c: *mut c_double, + ldc: *const c_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn sgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f32, + a: &[f32], + lda: i32, + b: &[f32], + ldb: i32, + beta: f32, + c: &mut [f32], + ldc: i32, +) { + ffi::sgemm_ffi( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn dgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f64, + a: &[f64], + lda: i32, + b: &[f64], + ldb: i32, + beta: f64, + c: &mut [f64], + ldc: i32, +) { + ffi::dgemm_ffi( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} |