summaryrefslogtreecommitdiff
path: root/candle-core/src/accelerate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/accelerate.rs')
-rw-r--r-- candle-core/src/accelerate.rs 111
1 files changed, 111 insertions, 0 deletions
diff --git a/candle-core/src/accelerate.rs b/candle-core/src/accelerate.rs
new file mode 100644
index 00000000..8b0df5c1
--- /dev/null
+++ b/candle-core/src/accelerate.rs
@@ -0,0 +1,111 @@
+#![allow(dead_code)]
+use libc::{c_char, c_double, c_float, c_int};
+
+/// Raw declarations of the externally linked GEMM routines.
+///
+/// Every parameter, scalars included, is declared as a pointer.
+/// NOTE(review): this looks like the Fortran-style BLAS calling convention exposed by
+/// Accelerate — confirm against the framework headers.
+mod ffi {
+    use super::{c_char, c_double, c_float, c_int};
+
+    extern "C" {
+        // Binding to the NEWLAPACK flavour of these functions would be preferable, but doing so
+        // seems to trigger a link error. The symbol names that are available can be seen here:
+        // /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd
+
+        /// `f32` GEMM entry point, resolved at link time to the `sgemm_` symbol.
+        #[link_name = "sgemm_"]
+        pub fn sgemm_ffi(
+            transa: *const c_char,
+            transb: *const c_char,
+            m: *const c_int,
+            n: *const c_int,
+            k: *const c_int,
+            alpha: *const c_float,
+            a: *const c_float,
+            lda: *const c_int,
+            b: *const c_float,
+            ldb: *const c_int,
+            beta: *const c_float,
+            c: *mut c_float,
+            ldc: *const c_int,
+        );
+
+        /// `f64` GEMM entry point, resolved at link time to the `dgemm_` symbol.
+        #[link_name = "dgemm_"]
+        pub fn dgemm_ffi(
+            transa: *const c_char,
+            transb: *const c_char,
+            m: *const c_int,
+            n: *const c_int,
+            k: *const c_int,
+            alpha: *const c_double,
+            a: *const c_double,
+            lda: *const c_int,
+            b: *const c_double,
+            ldb: *const c_int,
+            beta: *const c_double,
+            c: *mut c_double,
+            ldc: *const c_int,
+        );
+    }
+}
+
+/// Single-precision general matrix multiply, forwarded to the foreign `sgemm_` routine.
+///
+/// Each scalar argument is handed to the foreign routine by reference, and each slice is handed
+/// over as a raw pointer to its first element. Nothing is inspected or validated on the Rust
+/// side; `c` is the only buffer passed as mutable.
+///
+/// NOTE(review): `sgemm_` is presumably the Fortran-convention BLAS routine (column-major,
+/// `C := alpha * op(A) * op(B) + beta * C`) — confirm against the BLAS reference.
+///
+/// # Safety
+///
+/// The lengths of `a`, `b` and `c` are neither passed to the foreign routine nor checked against
+/// `m`, `n`, `k`, `lda`, `ldb` and `ldc`. The caller must guarantee that all three slices are
+/// large enough for every element the routine reads or writes under those dimensions, and that
+/// `transa` / `transb` are flag bytes the routine accepts. Violating this is undefined behavior.
+// The argument list mirrors the foreign declaration one-to-one, hence the long signature.
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn sgemm(
+    transa: u8,
+    transb: u8,
+    m: i32,
+    n: i32,
+    k: i32,
+    alpha: f32,
+    a: &[f32],
+    lda: i32,
+    b: &[f32],
+    ldb: i32,
+    beta: f32,
+    c: &mut [f32],
+    ldc: i32,
+) {
+    // The foreign declaration takes the flags as `c_char` pointers, so convert once up front.
+    let transa = transa as c_char;
+    let transb = transb as c_char;
+    // SAFETY: every scalar pointer refers to a live local or parameter for the duration of the
+    // call, and the buffer pointers come from valid slices. That the buffers are large enough
+    // for the requested dimensions is the caller's obligation (see `# Safety`).
+    unsafe {
+        ffi::sgemm_ffi(
+            &transa,
+            &transb,
+            &m,
+            &n,
+            &k,
+            &alpha,
+            a.as_ptr(),
+            &lda,
+            b.as_ptr(),
+            &ldb,
+            &beta,
+            c.as_mut_ptr(),
+            &ldc,
+        )
+    }
+}
+
+/// Double-precision general matrix multiply, forwarded to the foreign `dgemm_` routine.
+///
+/// Each scalar argument is handed to the foreign routine by reference, and each slice is handed
+/// over as a raw pointer to its first element. Nothing is inspected or validated on the Rust
+/// side; `c` is the only buffer passed as mutable.
+///
+/// NOTE(review): `dgemm_` is presumably the Fortran-convention BLAS routine (column-major,
+/// `C := alpha * op(A) * op(B) + beta * C`) — confirm against the BLAS reference.
+///
+/// # Safety
+///
+/// The lengths of `a`, `b` and `c` are neither passed to the foreign routine nor checked against
+/// `m`, `n`, `k`, `lda`, `ldb` and `ldc`. The caller must guarantee that all three slices are
+/// large enough for every element the routine reads or writes under those dimensions, and that
+/// `transa` / `transb` are flag bytes the routine accepts. Violating this is undefined behavior.
+// The argument list mirrors the foreign declaration one-to-one, hence the long signature.
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn dgemm(
+    transa: u8,
+    transb: u8,
+    m: i32,
+    n: i32,
+    k: i32,
+    alpha: f64,
+    a: &[f64],
+    lda: i32,
+    b: &[f64],
+    ldb: i32,
+    beta: f64,
+    c: &mut [f64],
+    ldc: i32,
+) {
+    // The foreign declaration takes the flags as `c_char` pointers, so convert once up front.
+    let transa = transa as c_char;
+    let transb = transb as c_char;
+    // SAFETY: every scalar pointer refers to a live local or parameter for the duration of the
+    // call, and the buffer pointers come from valid slices. That the buffers are large enough
+    // for the requested dimensions is the caller's obligation (see `# Safety`).
+    unsafe {
+        ffi::dgemm_ffi(
+            &transa,
+            &transb,
+            &m,
+            &n,
+            &k,
+            &alpha,
+            a.as_ptr(),
+            &lda,
+            b.as_ptr(),
+            &ldb,
+            &beta,
+            c.as_mut_ptr(),
+            &ldc,
+        )
+    }
+}