diff options
author | laurent <laurent.mazare@gmail.com> | 2023-07-02 20:12:26 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-07-02 20:12:26 +0100 |
commit | 78871ffe38a9ae0b6e4a905ab7d0329b7f3567c3 (patch) | |
tree | 5cc4b6b8c298f227e02d8567992e5bf2f112e261 /candle-pyo3/src/lib.rs | |
parent | 65e069384c4c1995467592f9f6317e9b6c49981d (diff) | |
download | candle-78871ffe38a9ae0b6e4a905ab7d0329b7f3567c3.tar.gz candle-78871ffe38a9ae0b6e4a905ab7d0329b7f3567c3.tar.bz2 candle-78871ffe38a9ae0b6e4a905ab7d0329b7f3567c3.zip |
Add dtype support.
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 35 |
1 files changed, 29 insertions, 6 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index b1504ada..7da91b3f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,6 +1,6 @@ use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{PyString, PyTuple}; +use pyo3::types::PyTuple; use half::{bf16, f16}; @@ -22,13 +22,32 @@ impl std::ops::Deref for PyTensor { } } -trait PyDType: WithDType { +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct PyDType(DType); + +impl<'source> FromPyObject<'source> for PyDType { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + use std::str::FromStr; + let dtype: &str = ob.extract()?; + let dtype = DType::from_str(dtype) + .map_err(|_| PyTypeError::new_err(format!("invalid dtype {dtype}")))?; + Ok(Self(dtype)) + } +} + +impl ToPyObject for PyDType { + fn to_object(&self, py: Python<'_>) -> PyObject { + self.0.as_str().to_object(py) + } +} + +trait PyWithDType: WithDType { fn to_py(&self, py: Python<'_>) -> PyObject; } macro_rules! pydtype { ($ty:ty, $conv:expr) => { - impl PyDType for $ty { + impl PyWithDType for $ty { fn to_py(&self, py: Python<'_>) -> PyObject { $conv(*self).to_object(py) } @@ -45,7 +64,7 @@ pydtype!(f64, |v| v); // TODO: Something similar to this should probably be a part of candle core. trait MapDType { type Output; - fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>; + fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output>; fn map(&self, t: &Tensor) -> PyResult<Self::Output> { match t.dtype() { @@ -83,7 +102,7 @@ impl PyTensor { struct M<'a>(Python<'a>); impl<'a> MapDType for M<'a> { type Output = PyObject; - fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> { + fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output> { match t.rank() { 0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)), 1 => { @@ -133,7 +152,7 @@ impl PyTensor { #[getter] fn dtype(&self, py: Python<'_>) -> PyObject { - PyString::new(py, self.0.dtype().as_str()).to_object(py) + PyDType(self.0.dtype()).to_object(py) } #[getter] @@ -269,6 +288,10 @@ impl PyTensor { fn copy(&self) -> PyResult<Self> { Ok(PyTensor(self.0.copy().map_err(wrap_err)?)) } + + fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> { + Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?)) + } } /// Concatenate the tensors across one axis. |