summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-24 06:34:37 +0100
committerGitHub <noreply@github.com>2023-10-24 06:34:37 +0100
commit7bd0faba7592a150f2c29c1f754eb27529d1d6d1 (patch)
treec601eb68ff67aa8eb330635f00a0c9062b488079 /candle-pyo3
parent807e3f9f52a20d2f5d5688f14bfca8f9c4157e2a (diff)
downloadcandle-7bd0faba7592a150f2c29c1f754eb27529d1d6d1.tar.gz
candle-7bd0faba7592a150f2c29c1f754eb27529d1d6d1.tar.bz2
candle-7bd0faba7592a150f2c29c1f754eb27529d1d6d1.zip
Add support for accelerate in the pyo3 bindings. (#1167)
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/Cargo.toml4
-rw-r--r--candle-pyo3/src/lib.rs3
-rw-r--r--candle-pyo3/test.py5
3 files changed, 11 insertions, 1 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 8bccbcc6..0241d2b2 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -14,16 +14,18 @@ name = "candle"
crate-type = ["cdylib"]
[dependencies]
+accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
half = { workspace = true }
-pyo3 = { version = "0.19.0", features = ["extension-module"] }
intel-mkl-src = { workspace = true, optional = true }
+pyo3 = { version = "0.19.0", features = ["extension-module"] }
[build-dependencies]
pyo3-build-config = "0.19"
[features]
default = []
+accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src","candle/mkl"]
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 29f38ff8..e2c8014f 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -11,6 +11,9 @@ use half::{bf16, f16};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
pub fn wrap_err(err: ::candle::Error) -> PyErr {
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index a56ed22c..e4ff772a 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -1,5 +1,10 @@
import candle
+print(f"mkl: {candle.utils.has_mkl()}")
+print(f"accelerate: {candle.utils.has_accelerate()}")
+print(f"num-threads: {candle.utils.get_num_threads()}")
+print(f"cuda: {candle.utils.cuda_is_available()}")
+
t = candle.Tensor(42.0)
print(t)
print(t.shape, t.rank, t.device)