diff options
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/Cargo.toml | 4 | ||||
-rw-r--r-- | candle-pyo3/src/lib.rs | 47 | ||||
-rw-r--r-- | candle-pyo3/test.py | 4 |
3 files changed, 52 insertions, 3 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index fd2890f6..37244914 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -18,3 +18,7 @@ crate-type = ["cdylib"] candle = { path = "../candle-core", default-features=false } pyo3 = { version = "0.19.0", features = ["extension-module"] } half = { version = "2.3.1", features = ["num-traits"] } + +[features] +default = ["cuda"] +cuda = ["candle/cuda"] diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index d5d472d5..62eb21e8 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -4,7 +4,7 @@ use pyo3::types::PyTuple; use half::{bf16, f16}; -use ::candle::{DType, Device::Cpu, Tensor, WithDType}; +use ::candle::{DType, Device, Tensor, WithDType}; pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::<PyValueError, _>(format!("{err:?}")) @@ -30,7 +30,7 @@ impl<'source> FromPyObject<'source> for PyDType { use std::str::FromStr; let dtype: &str = ob.extract()?; let dtype = DType::from_str(dtype) - .map_err(|_| PyTypeError::new_err(format!("invalid dtype {dtype}")))?; + .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?; Ok(Self(dtype)) } } @@ -41,6 +41,43 @@ impl ToPyObject for PyDType { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum PyDevice { + Cpu, + Cuda, +} + +impl PyDevice { + fn from_device(device: Device) -> Self { + match device { + Device::Cpu => Self::Cpu, + Device::Cuda(_) => Self::Cuda, + } + } +} + +impl<'source> FromPyObject<'source> for PyDevice { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + let device: &str = ob.extract()?; + let device = match device { + "cpu" => PyDevice::Cpu, + "cuda" => PyDevice::Cuda, + _ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?, + }; + Ok(device) + } +} + +impl ToPyObject for PyDevice { + fn to_object(&self, py: Python<'_>) -> PyObject { + let str = match self { + PyDevice::Cpu => "cpu", + PyDevice::Cuda => "cuda", + }; + str.to_object(py) + } +} + trait PyWithDType: WithDType { fn to_py(&self, py: Python<'_>) -> PyObject; } @@ -83,6 +120,7 @@ impl PyTensor { #[new] // TODO: Handle arbitrary input dtype and shape. fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> { + use Device::Cpu; let tensor = if let Ok(vs) = vs.extract::<u32>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) { @@ -156,6 +194,11 @@ impl PyTensor { } #[getter] + fn device(&self, py: Python<'_>) -> PyObject { + PyDevice::from_device(self.0.device()).to_object(py) + } + + #[getter] fn rank(&self) -> usize { self.0.rank() } diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 1d792de5..8f906060 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -2,12 +2,14 @@ import candle t = candle.Tensor(42.0) print(t) -print("shape", t.shape, t.rank) +print(t.shape, t.rank, t.device) print(t + t) t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) print(t) print(t+t) + t = t.reshape([2, 4]) print(t.matmul(t.t())) + print(t.to_dtype("u8")) |