diff options
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/Cargo.toml | 2 | ||||
-rw-r--r-- | candle-pyo3/src/lib.rs | 61 | ||||
-rw-r--r-- | candle-pyo3/test.py | 17 |
3 files changed, 60 insertions, 20 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 610a1733..89263fe0 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -16,8 +16,8 @@ doc = false [dependencies] candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" } -pyo3 = { version = "0.19.0", features = ["extension-module"] } half = { workspace = true } +pyo3 = { version = "0.19.0", features = ["extension-module"] } [build-dependencies] pyo3-build-config = "0.19" diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index fd013b9b..1ff4db06 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -40,21 +40,30 @@ impl std::ops::Deref for PyTensor { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[pyclass(name = "DType")] struct PyDType(DType); -impl<'source> FromPyObject<'source> for PyDType { - fn extract(ob: &'source PyAny) -> PyResult<Self> { - use std::str::FromStr; - let dtype: &str = ob.extract()?; - let dtype = DType::from_str(dtype) - .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?; - Ok(Self(dtype)) +#[pymethods] +impl PyDType { + fn __repr__(&self) -> String { + format!("{:?}", self.0) + } + + fn __str__(&self) -> String { + self.__repr__() } } -impl ToPyObject for PyDType { - fn to_object(&self, py: Python<'_>) -> PyObject { - self.0.as_str().to_object(py) +impl PyDType { + fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> { + use std::str::FromStr; + if let Ok(dtype) = ob.extract::<&str>(py) { + let dtype = DType::from_str(dtype) + .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?; + Ok(Self(dtype)) + } else { + ob.extract(py) + } } } @@ -223,8 +232,8 @@ impl PyTensor { } #[getter] - fn dtype(&self, py: Python<'_>) -> PyObject { - PyDType(self.0.dtype()).to_object(py) + fn dtype(&self) -> PyDType { + PyDType(self.0.dtype()) } #[getter] @@ -367,7 +376,8 @@ impl PyTensor { Ok(PyTensor(self.0.copy().map_err(wrap_err)?)) } - fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> { + fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> { + let dtype = PyDType::from_pyobject(dtype, py)?; Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?)) } @@ -416,12 +426,15 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult< #[pyfunction] #[pyo3(signature = (shape, *, dtype=None, device=None))] fn ones( - _py: Python<'_>, + py: Python<'_>, shape: PyShape, - dtype: Option<PyDType>, + dtype: Option<PyObject>, device: Option<PyDevice>, ) -> PyResult<PyTensor> { - let dtype = dtype.map_or(DType::F32, |dt| dt.0); + let dtype = match dtype { + None => DType::F32, + Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, + }; let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) @@ -430,12 +443,15 @@ fn ones( #[pyfunction] #[pyo3(signature = (shape, *, dtype=None, device=None))] fn zeros( - _py: Python<'_>, + py: Python<'_>, shape: PyShape, - dtype: Option<PyDType>, + dtype: Option<PyObject>, device: Option<PyDevice>, ) -> PyResult<PyTensor> { - let dtype = dtype.map_or(DType::F32, |dt| dt.0); + let dtype = match dtype { + None => DType::F32, + Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, + }; let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) @@ -444,6 +460,13 @@ fn zeros( #[pymodule] fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::<PyTensor>()?; + m.add_class::<PyDType>()?; + m.add("u8", PyDType(DType::U8))?; + m.add("u32", PyDType(DType::U32))?; + m.add("bf16", PyDType(DType::BF16))?; + m.add("f16", PyDType(DType::F16))?; + m.add("f32", PyDType(DType::F32))?; + m.add("f64", PyDType(DType::F64))?; m.add_function(wrap_pyfunction!(cat, m)?)?; m.add_function(wrap_pyfunction!(ones, m)?)?; m.add_function(wrap_pyfunction!(rand, m)?)?; diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 160a099d..1711cdad 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,3 +1,18 @@ +import os +import sys + +# The "import candle" statement below works if there is a "candle.so" file in sys.path. +# Here we check for shared libraries that can be used in the build directory. +BUILD_DIR = "./target/release-with-debug" +so_file = BUILD_DIR + "/candle.so" +if os.path.islink(so_file): os.remove(so_file) +for lib_file in ["libcandle.dylib", "libcandle.so"]: + lib_file_ = BUILD_DIR + "/" + lib_file + if os.path.isfile(lib_file_): + os.symlink(lib_file, so_file) + sys.path.insert(0, BUILD_DIR) + break + import candle t = candle.Tensor(42.0) @@ -12,7 +27,9 @@ print(t+t) t = t.reshape([2, 4]) print(t.matmul(t.t())) +print(t.to_dtype(candle.u8)) print(t.to_dtype("u8")) t = candle.randn((5, 3)) print(t) +print(t.dtype) |