summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-02 20:21:37 +0100
committerGitHub <noreply@github.com>2023-07-02 20:21:37 +0100
commiteb6f7d30b6f8bae64e9958c27bc8f60f251e5c52 (patch)
tree6fb82508ce85011c8f1f68fb784ef496e5e35ad2
parent65e069384c4c1995467592f9f6317e9b6c49981d (diff)
parentbdb257ceabcbec1a3401b9f02c817d4f42c46a1e (diff)
downloadcandle-eb6f7d30b6f8bae64e9958c27bc8f60f251e5c52.tar.gz
candle-eb6f7d30b6f8bae64e9958c27bc8f60f251e5c52.tar.bz2
candle-eb6f7d30b6f8bae64e9958c27bc8f60f251e5c52.zip
Merge pull request #54 from LaurentMazare/more-pyo3-2
Add dtype support in the pyo3 bindings.
-rw-r--r--candle-core/src/dtype.rs18
-rw-r--r--candle-pyo3/src/lib.rs41
-rw-r--r--candle-pyo3/test.py1
3 files changed, 54 insertions, 6 deletions
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index e6785491..8ce70f64 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -10,6 +10,24 @@ pub enum DType {
F64,
}
+#[derive(Debug, PartialEq, Eq)]
+pub struct DTypeParseError;
+
+impl std::str::FromStr for DType {
+ type Err = DTypeParseError;
+ fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+ match s {
+ "u8" => Ok(Self::U8),
+ "u32" => Ok(Self::U32),
+ "bf16" => Ok(Self::BF16),
+ "f16" => Ok(Self::F16),
+ "f32" => Ok(Self::F32),
+ "f64" => Ok(Self::F64),
+ _ => Err(DTypeParseError),
+ }
+ }
+}
+
impl DType {
pub fn as_str(&self) -> &'static str {
match self {
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index b1504ada..d5d472d5 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,6 +1,6 @@
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
-use pyo3::types::{PyString, PyTuple};
+use pyo3::types::PyTuple;
use half::{bf16, f16};
@@ -22,13 +22,32 @@ impl std::ops::Deref for PyTensor {
}
}
-trait PyDType: WithDType {
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+struct PyDType(DType);
+
+impl<'source> FromPyObject<'source> for PyDType {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ use std::str::FromStr;
+ let dtype: &str = ob.extract()?;
+ let dtype = DType::from_str(dtype)
+ .map_err(|_| PyTypeError::new_err(format!("invalid dtype {dtype}")))?;
+ Ok(Self(dtype))
+ }
+}
+
+impl ToPyObject for PyDType {
+ fn to_object(&self, py: Python<'_>) -> PyObject {
+ self.0.as_str().to_object(py)
+ }
+}
+
+trait PyWithDType: WithDType {
fn to_py(&self, py: Python<'_>) -> PyObject;
}
macro_rules! pydtype {
($ty:ty, $conv:expr) => {
- impl PyDType for $ty {
+ impl PyWithDType for $ty {
fn to_py(&self, py: Python<'_>) -> PyObject {
$conv(*self).to_object(py)
}
@@ -45,7 +64,7 @@ pydtype!(f64, |v| v);
// TODO: Something similar to this should probably be a part of candle core.
trait MapDType {
type Output;
- fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output>;
+ fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output>;
fn map(&self, t: &Tensor) -> PyResult<Self::Output> {
match t.dtype() {
@@ -83,7 +102,7 @@ impl PyTensor {
struct M<'a>(Python<'a>);
impl<'a> MapDType for M<'a> {
type Output = PyObject;
- fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
+ fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
match t.rank() {
0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
1 => {
@@ -133,7 +152,7 @@ impl PyTensor {
#[getter]
fn dtype(&self, py: Python<'_>) -> PyObject {
- PyString::new(py, self.0.dtype().as_str()).to_object(py)
+ PyDType(self.0.dtype()).to_object(py)
}
#[getter]
@@ -269,6 +288,10 @@ impl PyTensor {
fn copy(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
}
+
+ fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
+ Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
+ }
}
/// Concatenate the tensors across one axis.
@@ -286,10 +309,16 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
Ok(PyTensor(tensor))
}
+#[pyfunction]
+fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
+ PyTensor::new(py, vs)
+}
+
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
+ m.add_function(wrap_pyfunction!(tensor, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
Ok(())
}
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index d63f752b..1d792de5 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -10,3 +10,4 @@ print(t)
print(t+t)
t = t.reshape([2, 4])
print(t.matmul(t.t()))
+print(t.to_dtype("u8"))