diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-24 06:34:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-24 06:34:37 +0100 |
commit | 7bd0faba7592a150f2c29c1f754eb27529d1d6d1 (patch) | |
tree | c601eb68ff67aa8eb330635f00a0c9062b488079 /candle-pyo3 | |
parent | 807e3f9f52a20d2f5d5688f14bfca8f9c4157e2a (diff) | |
download | candle-7bd0faba7592a150f2c29c1f754eb27529d1d6d1.tar.gz candle-7bd0faba7592a150f2c29c1f754eb27529d1d6d1.tar.bz2 candle-7bd0faba7592a150f2c29c1f754eb27529d1d6d1.zip |
Add support for accelerate in the pyo3 bindings. (#1167)
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/Cargo.toml | 4 | ||||
-rw-r--r-- | candle-pyo3/src/lib.rs | 3 | ||||
-rw-r--r-- | candle-pyo3/test.py | 5 |
3 files changed, 11 insertions, 1 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 8bccbcc6..0241d2b2 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -14,16 +14,18 @@ name = "candle" crate-type = ["cdylib"] [dependencies] +accelerate-src = { workspace = true, optional = true } candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-nn = { path = "../candle-nn", version = "0.3.0" } half = { workspace = true } -pyo3 = { version = "0.19.0", features = ["extension-module"] } intel-mkl-src = { workspace = true, optional = true } +pyo3 = { version = "0.19.0", features = ["extension-module"] } [build-dependencies] pyo3-build-config = "0.19" [features] default = [] +accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src","candle/mkl"] diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 29f38ff8..e2c8014f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -11,6 +11,9 @@ use half::{bf16, f16}; #[cfg(feature = "mkl")] extern crate intel_mkl_src; +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; pub fn wrap_err(err: ::candle::Error) -> PyErr { diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index a56ed22c..e4ff772a 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,5 +1,10 @@ import candle +print(f"mkl: {candle.utils.has_mkl()}") +print(f"accelerate: {candle.utils.has_accelerate()}") +print(f"num-threads: {candle.utils.get_num_threads()}") +print(f"cuda: {candle.utils.cuda_is_available()}") + t = candle.Tensor(42.0) print(t) print(t.shape, t.rank, t.device) |