summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/lib.rs
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-02 20:42:55 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-02 20:42:55 +0100
commitfbfe74caab8835d758a1a2bb9ab1c62c9afd50d5 (patch)
treeb577884df10ad2fe3ca9a401cb48520142856b1e /candle-pyo3/src/lib.rs
parenteb6f7d30b6f8bae64e9958c27bc8f60f251e5c52 (diff)
downloadcandle-fbfe74caab8835d758a1a2bb9ab1c62c9afd50d5.tar.gz
candle-fbfe74caab8835d758a1a2bb9ab1c62c9afd50d5.tar.bz2
candle-fbfe74caab8835d758a1a2bb9ab1c62c9afd50d5.zip
Preliminary pyo3 support for device.
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r--candle-pyo3/src/lib.rs47
1 files changed, 45 insertions, 2 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index d5d472d5..62eb21e8 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -4,7 +4,7 @@ use pyo3::types::PyTuple;
use half::{bf16, f16};
-use ::candle::{DType, Device::Cpu, Tensor, WithDType};
+use ::candle::{DType, Device, Tensor, WithDType};
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
@@ -30,7 +30,7 @@ impl<'source> FromPyObject<'source> for PyDType {
use std::str::FromStr;
let dtype: &str = ob.extract()?;
let dtype = DType::from_str(dtype)
- .map_err(|_| PyTypeError::new_err(format!("invalid dtype {dtype}")))?;
+ .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
Ok(Self(dtype))
}
}
@@ -41,6 +41,43 @@ impl ToPyObject for PyDType {
}
}
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+enum PyDevice {
+ Cpu,
+ Cuda,
+}
+
+impl PyDevice {
+ fn from_device(device: Device) -> Self {
+ match device {
+ Device::Cpu => Self::Cpu,
+ Device::Cuda(_) => Self::Cuda,
+ }
+ }
+}
+
+impl<'source> FromPyObject<'source> for PyDevice {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ let device: &str = ob.extract()?;
+ let device = match device {
+ "cpu" => PyDevice::Cpu,
+ "cuda" => PyDevice::Cuda,
+ _ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
+ };
+ Ok(device)
+ }
+}
+
+impl ToPyObject for PyDevice {
+ fn to_object(&self, py: Python<'_>) -> PyObject {
+ let str = match self {
+ PyDevice::Cpu => "cpu",
+ PyDevice::Cuda => "cuda",
+ };
+ str.to_object(py)
+ }
+}
+
trait PyWithDType: WithDType {
fn to_py(&self, py: Python<'_>) -> PyObject;
}
@@ -83,6 +120,7 @@ impl PyTensor {
#[new]
// TODO: Handle arbitrary input dtype and shape.
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
+ use Device::Cpu;
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
@@ -156,6 +194,11 @@ impl PyTensor {
}
#[getter]
+ fn device(&self, py: Python<'_>) -> PyObject {
+ PyDevice::from_device(self.0.device()).to_object(py)
+ }
+
+ #[getter]
fn rank(&self) -> usize {
self.0.rank()
}