diff options
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/Cargo.toml | 5 | ||||
-rw-r--r-- | candle-pyo3/README.md | 10 | ||||
-rw-r--r-- | candle-pyo3/build.rs | 3 | ||||
-rw-r--r-- | candle-pyo3/src/lib.rs | 123 | ||||
-rw-r--r-- | candle-pyo3/test.py | 20 |
5 files changed, 140 insertions, 21 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index e5ebe953..89263fe0 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -16,8 +16,11 @@ doc = false [dependencies] candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" } -pyo3 = { version = "0.19.0", features = ["extension-module"] } half = { workspace = true } +pyo3 = { version = "0.19.0", features = ["extension-module"] } + +[build-dependencies] +pyo3-build-config = "0.19" [features] default = [] diff --git a/candle-pyo3/README.md b/candle-pyo3/README.md index 1887f269..f716b092 100644 --- a/candle-pyo3/README.md +++ b/candle-pyo3/README.md @@ -1,5 +1,11 @@ -From the top level directory run: +From the top level directory run the following for linux. ``` -cargo build --release --package candle-pyo3 && cp -f ./target/release/libcandle.so candle.so +cargo build --profile=release-with-debug --package candle-pyo3 && cp -f ./target/release-with-debug/libcandle.so candle.so +PYTHONPATH=. python3 candle-pyo3/test.py +```bash + + Or for macOS users: +```bash +cargo build --profile=release-with-debug --package candle-pyo3 && cp -f ./target/release-with-debug/libcandle.dylib candle.so PYTHONPATH=. python3 candle-pyo3/test.py ``` diff --git a/candle-pyo3/build.rs b/candle-pyo3/build.rs new file mode 100644 index 00000000..dace4a9b --- /dev/null +++ b/candle-pyo3/build.rs @@ -0,0 +1,3 @@ +fn main() { + pyo3_build_config::add_extension_module_link_args(); +} diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 136f8a4f..1ff4db06 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,3 +1,4 @@ +// TODO: Handle negative dimension indexes. use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyTuple; @@ -10,7 +11,23 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::<PyValueError, _>(format!("{err:?}")) } -#[derive(Clone)] +#[derive(Clone, Debug)] +struct PyShape(Vec<usize>); + +impl<'source> pyo3::FromPyObject<'source> for PyShape { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?; + Ok(PyShape(dims)) + } +} + +impl From<PyShape> for ::candle::Shape { + fn from(val: PyShape) -> Self { + val.0.into() + } +} + +#[derive(Clone, Debug)] #[pyclass(name = "Tensor")] struct PyTensor(Tensor); @@ -23,21 +40,30 @@ impl std::ops::Deref for PyTensor { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[pyclass(name = "DType")] struct PyDType(DType); -impl<'source> FromPyObject<'source> for PyDType { - fn extract(ob: &'source PyAny) -> PyResult<Self> { - use std::str::FromStr; - let dtype: &str = ob.extract()?; - let dtype = DType::from_str(dtype) - .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?; - Ok(Self(dtype)) +#[pymethods] +impl PyDType { + fn __repr__(&self) -> String { + format!("{:?}", self.0) + } + + fn __str__(&self) -> String { + self.__repr__() } } -impl ToPyObject for PyDType { - fn to_object(&self, py: Python<'_>) -> PyObject { - self.0.as_str().to_object(py) +impl PyDType { + fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> { + use std::str::FromStr; + if let Ok(dtype) = ob.extract::<&str>(py) { + let dtype = DType::from_str(dtype) + .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?; + Ok(Self(dtype)) + } else { + ob.extract(py) + } } } @@ -206,8 +232,8 @@ impl PyTensor { } #[getter] - fn dtype(&self, py: Python<'_>) -> PyObject { - PyDType(self.0.dtype()).to_object(py) + fn dtype(&self) -> PyDType { + PyDType(self.0.dtype()) } #[getter] @@ -279,16 +305,15 @@ impl PyTensor { Ok(Self(tensor)) } - // TODO: Add a PyShape type? - fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> { + fn reshape(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) } - fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> { + fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) } - fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> { + fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) } @@ -351,7 +376,8 @@ impl PyTensor { Ok(PyTensor(self.0.copy().map_err(wrap_err)?)) } - fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> { + fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> { + let dtype = PyDType::from_pyobject(dtype, py)?; Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?)) } @@ -381,11 +407,72 @@ fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> { PyTensor::new(py, vs) } +#[pyfunction] +#[pyo3(signature = (shape, *, device=None))] +fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, device=None))] +fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, dtype=None, device=None))] +fn ones( + py: Python<'_>, + shape: PyShape, + dtype: Option<PyObject>, + device: Option<PyDevice>, +) -> PyResult<PyTensor> { + let dtype = match dtype { + None => DType::F32, + Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, + }; + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + +#[pyfunction] +#[pyo3(signature = (shape, *, dtype=None, device=None))] +fn zeros( + py: Python<'_>, + shape: PyShape, + dtype: Option<PyObject>, + device: Option<PyDevice>, +) -> PyResult<PyTensor> { + let dtype = match dtype { + None => DType::F32, + Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, + }; + let device = device.unwrap_or(PyDevice::Cpu).as_device()?; + let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?; + Ok(PyTensor(tensor)) +} + #[pymodule] fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::<PyTensor>()?; + m.add_class::<PyDType>()?; + m.add("u8", PyDType(DType::U8))?; + m.add("u32", PyDType(DType::U32))?; + m.add("bf16", PyDType(DType::BF16))?; + m.add("f16", PyDType(DType::F16))?; + m.add("f32", PyDType(DType::F32))?; + m.add("f64", PyDType(DType::F64))?; m.add_function(wrap_pyfunction!(cat, m)?)?; + m.add_function(wrap_pyfunction!(ones, m)?)?; + m.add_function(wrap_pyfunction!(rand, m)?)?; + m.add_function(wrap_pyfunction!(randn, m)?)?; m.add_function(wrap_pyfunction!(tensor, m)?)?; m.add_function(wrap_pyfunction!(stack, m)?)?; + m.add_function(wrap_pyfunction!(zeros, m)?)?; Ok(()) } diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 8f906060..1711cdad 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,3 +1,18 @@ +import os +import sys + +# The "import candle" statement below works if there is a "candle.so" file in sys.path. +# Here we check for shared libraries that can be used in the build directory. +BUILD_DIR = "./target/release-with-debug" +so_file = BUILD_DIR + "/candle.so" +if os.path.islink(so_file): os.remove(so_file) +for lib_file in ["libcandle.dylib", "libcandle.so"]: + lib_file_ = BUILD_DIR + "/" + lib_file + if os.path.isfile(lib_file_): + os.symlink(lib_file, so_file) + sys.path.insert(0, BUILD_DIR) + break + import candle t = candle.Tensor(42.0) @@ -12,4 +27,9 @@ print(t+t) t = t.reshape([2, 4]) print(t.matmul(t.t())) +print(t.to_dtype(candle.u8)) print(t.to_dtype("u8")) + +t = candle.randn((5, 3)) +print(t) +print(t.dtype) |