diff options
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r-- | candle-pyo3/src/lib.rs | 123 |
1 files changed, 105 insertions, 18 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 136f8a4f..1ff4db06 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,3 +1,4 @@ +// TODO: Handle negative dimension indexes. use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyTuple; @@ -10,7 +11,23 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::<PyValueError, _>(format!("{err:?}")) } -#[derive(Clone)] +#[derive(Clone, Debug)] +struct PyShape(Vec<usize>); + +impl<'source> pyo3::FromPyObject<'source> for PyShape { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?; + Ok(PyShape(dims)) + } +} + +impl From<PyShape> for ::candle::Shape { + fn from(val: PyShape) -> Self { + val.0.into() + } +} + +#[derive(Clone, Debug)] #[pyclass(name = "Tensor")] struct PyTensor(Tensor); @@ -23,21 +40,30 @@ impl std::ops::Deref for PyTensor { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[pyclass(name = "DType")] struct PyDType(DType); -impl<'source> FromPyObject<'source> for PyDType { - fn extract(ob: &'source PyAny) -> PyResult<Self> { - use std::str::FromStr; - let dtype: &str = ob.extract()?; - let dtype = DType::from_str(dtype) - .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?; - Ok(Self(dtype)) +#[pymethods] +impl PyDType { + fn __repr__(&self) -> String { + format!("{:?}", self.0) + } + + fn __str__(&self) -> String { + self.__repr__() } } -impl ToPyObject for PyDType { - fn to_object(&self, py: Python<'_>) -> PyObject { - self.0.as_str().to_object(py) +impl PyDType { + fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> { + use std::str::FromStr; + if let Ok(dtype) = ob.extract::<&str>(py) { + let dtype = DType::from_str(dtype) + .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?; + Ok(Self(dtype)) + } else { + ob.extract(py) + } } } @@ -206,8 +232,8 @@ impl PyTensor { } #[getter] - fn dtype(&self, py: Python<'_>) -> PyObject { - PyDType(self.0.dtype()).to_object(py) + fn dtype(&self) -> PyDType { + PyDType(self.0.dtype()) } #[getter] @@ -279,16 +305,15 @@ impl PyTensor { Ok(Self(tensor)) } - // TODO: Add a PyShape type? - fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> { + fn reshape(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) } - fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> { + fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) } - fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> { + fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) } @@ -351,7 +376,8 @@ impl PyTensor { Ok(PyTensor(self.0.copy().map_err(wrap_err)?)) } - fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> { + fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> { + let dtype = PyDType::from_pyobject(dtype, py)?; Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?)) } @@ -381,11 +407,72 @@ fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> { PyTensor::new(py, vs) } +#[pyfunction] +#[pyo3(signature = (shape, *, device=None))] +fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, device=None))] +fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, dtype=None, device=None))] +fn ones( + py: Python<'_>, + shape: PyShape, + dtype: Option<PyObject>, + device: Option<PyDevice>, +) -> PyResult<PyTensor> { + let dtype = match dtype { + None => DType::F32, + Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, + }; + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, dtype=None, device=None))] +fn zeros( + py: Python<'_>, + shape: PyShape, + dtype: Option<PyObject>, + device: Option<PyDevice>, +) -> PyResult<PyTensor> { + let dtype = match dtype { + None => DType::F32, + Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, + }; + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + #[pymodule] fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::<PyTensor>()?; + m.add_class::<PyDType>()?; + m.add("u8", PyDType(DType::U8))?; + m.add("u32", PyDType(DType::U32))?; + m.add("bf16", PyDType(DType::BF16))?; + m.add("f16", PyDType(DType::F16))?; + m.add("f32", PyDType(DType::F32))?; + m.add("f64", PyDType(DType::F64))?; m.add_function(wrap_pyfunction!(cat, m)?)?; + m.add_function(wrap_pyfunction!(ones, m)?)?; + m.add_function(wrap_pyfunction!(rand, m)?)?; + m.add_function(wrap_pyfunction!(randn, m)?)?; m.add_function(wrap_pyfunction!(tensor, m)?)?; m.add_function(wrap_pyfunction!(stack, m)?)?; + m.add_function(wrap_pyfunction!(zeros, m)?)?; Ok(()) } |