summaryrefslogtreecommitdiff
path: root/candle-pyo3/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r--candle-pyo3/src/lib.rs123
1 files changed, 105 insertions, 18 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 136f8a4f..1ff4db06 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,3 +1,4 @@
+// TODO: Handle negative dimension indexes.
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
@@ -10,7 +11,23 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
-#[derive(Clone)]
+#[derive(Clone, Debug)]
+struct PyShape(Vec<usize>);
+
+impl<'source> pyo3::FromPyObject<'source> for PyShape {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
+ Ok(PyShape(dims))
+ }
+}
+
+impl From<PyShape> for ::candle::Shape {
+ fn from(val: PyShape) -> Self {
+ val.0.into()
+ }
+}
+
+#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
struct PyTensor(Tensor);
@@ -23,21 +40,30 @@ impl std::ops::Deref for PyTensor {
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[pyclass(name = "DType")]
struct PyDType(DType);
-impl<'source> FromPyObject<'source> for PyDType {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
- use std::str::FromStr;
- let dtype: &str = ob.extract()?;
- let dtype = DType::from_str(dtype)
- .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
- Ok(Self(dtype))
+#[pymethods]
+impl PyDType {
+ fn __repr__(&self) -> String {
+ format!("{:?}", self.0)
+ }
+
+ fn __str__(&self) -> String {
+ self.__repr__()
}
}
-impl ToPyObject for PyDType {
- fn to_object(&self, py: Python<'_>) -> PyObject {
- self.0.as_str().to_object(py)
+impl PyDType {
+ fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
+ use std::str::FromStr;
+ if let Ok(dtype) = ob.extract::<&str>(py) {
+ let dtype = DType::from_str(dtype)
+ .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
+ Ok(Self(dtype))
+ } else {
+ ob.extract(py)
+ }
}
}
@@ -206,8 +232,8 @@ impl PyTensor {
}
#[getter]
- fn dtype(&self, py: Python<'_>) -> PyObject {
- PyDType(self.0.dtype()).to_object(py)
+ fn dtype(&self) -> PyDType {
+ PyDType(self.0.dtype())
}
#[getter]
@@ -279,16 +305,15 @@ impl PyTensor {
Ok(Self(tensor))
}
- // TODO: Add a PyShape type?
- fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
+ fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
}
- fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
+ fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
}
- fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
+ fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
}
@@ -351,7 +376,8 @@ impl PyTensor {
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
}
- fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
+ fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
+ let dtype = PyDType::from_pyobject(dtype, py)?;
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
}
@@ -381,11 +407,72 @@ fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
PyTensor::new(py, vs)
}
+#[pyfunction]
+#[pyo3(signature = (shape, *, device=None))]
+fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
+#[pyo3(signature = (shape, *, device=None))]
+fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
+#[pyo3(signature = (shape, *, dtype=None, device=None))]
+fn ones(
+ py: Python<'_>,
+ shape: PyShape,
+ dtype: Option<PyObject>,
+ device: Option<PyDevice>,
+) -> PyResult<PyTensor> {
+ let dtype = match dtype {
+ None => DType::F32,
+ Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
+ };
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
+#[pyo3(signature = (shape, *, dtype=None, device=None))]
+fn zeros(
+ py: Python<'_>,
+ shape: PyShape,
+ dtype: Option<PyObject>,
+ device: Option<PyDevice>,
+) -> PyResult<PyTensor> {
+ let dtype = match dtype {
+ None => DType::F32,
+ Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
+ };
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
+ m.add_class::<PyDType>()?;
+ m.add("u8", PyDType(DType::U8))?;
+ m.add("u32", PyDType(DType::U32))?;
+ m.add("bf16", PyDType(DType::BF16))?;
+ m.add("f16", PyDType(DType::F16))?;
+ m.add("f32", PyDType(DType::F32))?;
+ m.add("f64", PyDType(DType::F64))?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
+ m.add_function(wrap_pyfunction!(ones, m)?)?;
+ m.add_function(wrap_pyfunction!(rand, m)?)?;
+ m.add_function(wrap_pyfunction!(randn, m)?)?;
m.add_function(wrap_pyfunction!(tensor, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
+ m.add_function(wrap_pyfunction!(zeros, m)?)?;
Ok(())
}