summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/Cargo.toml2
-rw-r--r--candle-pyo3/src/lib.rs61
-rw-r--r--candle-pyo3/test.py17
3 files changed, 60 insertions, 20 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 610a1733..89263fe0 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -16,8 +16,8 @@ doc = false
[dependencies]
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
-pyo3 = { version = "0.19.0", features = ["extension-module"] }
half = { workspace = true }
+pyo3 = { version = "0.19.0", features = ["extension-module"] }
[build-dependencies]
pyo3-build-config = "0.19"
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index fd013b9b..1ff4db06 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -40,21 +40,30 @@ impl std::ops::Deref for PyTensor {
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[pyclass(name = "DType")]
struct PyDType(DType);
-impl<'source> FromPyObject<'source> for PyDType {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
- use std::str::FromStr;
- let dtype: &str = ob.extract()?;
- let dtype = DType::from_str(dtype)
- .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
- Ok(Self(dtype))
+#[pymethods]
+impl PyDType {
+ fn __repr__(&self) -> String {
+ format!("{:?}", self.0)
+ }
+
+ fn __str__(&self) -> String {
+ self.__repr__()
}
}
-impl ToPyObject for PyDType {
- fn to_object(&self, py: Python<'_>) -> PyObject {
- self.0.as_str().to_object(py)
+impl PyDType {
+ fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
+ use std::str::FromStr;
+ if let Ok(dtype) = ob.extract::<&str>(py) {
+ let dtype = DType::from_str(dtype)
+ .map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
+ Ok(Self(dtype))
+ } else {
+ ob.extract(py)
+ }
}
}
@@ -223,8 +232,8 @@ impl PyTensor {
}
#[getter]
- fn dtype(&self, py: Python<'_>) -> PyObject {
- PyDType(self.0.dtype()).to_object(py)
+ fn dtype(&self) -> PyDType {
+ PyDType(self.0.dtype())
}
#[getter]
@@ -367,7 +376,8 @@ impl PyTensor {
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
}
- fn to_dtype(&self, dtype: PyDType) -> PyResult<Self> {
+ fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
+ let dtype = PyDType::from_pyobject(dtype, py)?;
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
}
@@ -416,12 +426,15 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None))]
fn ones(
- _py: Python<'_>,
+ py: Python<'_>,
shape: PyShape,
- dtype: Option<PyDType>,
+ dtype: Option<PyObject>,
device: Option<PyDevice>,
) -> PyResult<PyTensor> {
- let dtype = dtype.map_or(DType::F32, |dt| dt.0);
+ let dtype = match dtype {
+ None => DType::F32,
+ Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
+ };
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
@@ -430,12 +443,15 @@ fn ones(
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None))]
fn zeros(
- _py: Python<'_>,
+ py: Python<'_>,
shape: PyShape,
- dtype: Option<PyDType>,
+ dtype: Option<PyObject>,
device: Option<PyDevice>,
) -> PyResult<PyTensor> {
- let dtype = dtype.map_or(DType::F32, |dt| dt.0);
+ let dtype = match dtype {
+ None => DType::F32,
+ Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
+ };
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
@@ -444,6 +460,13 @@ fn zeros(
#[pymodule]
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyTensor>()?;
+ m.add_class::<PyDType>()?;
+ m.add("u8", PyDType(DType::U8))?;
+ m.add("u32", PyDType(DType::U32))?;
+ m.add("bf16", PyDType(DType::BF16))?;
+ m.add("f16", PyDType(DType::F16))?;
+ m.add("f32", PyDType(DType::F32))?;
+ m.add("f64", PyDType(DType::F64))?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
m.add_function(wrap_pyfunction!(ones, m)?)?;
m.add_function(wrap_pyfunction!(rand, m)?)?;
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index 160a099d..1711cdad 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -1,3 +1,18 @@
+import os
+import sys
+
+# The "import candle" statement below works if there is a "candle.so" file in sys.path.
+# Here we check for shared libraries that can be used in the build directory.
+BUILD_DIR = "./target/release-with-debug"
+so_file = BUILD_DIR + "/candle.so"
+if os.path.islink(so_file): os.remove(so_file)
+for lib_file in ["libcandle.dylib", "libcandle.so"]:
+ lib_file_ = BUILD_DIR + "/" + lib_file
+ if os.path.isfile(lib_file_):
+ os.symlink(lib_file, so_file)
+ sys.path.insert(0, BUILD_DIR)
+ break
+
import candle
t = candle.Tensor(42.0)
@@ -12,7 +27,9 @@ print(t+t)
t = t.reshape([2, 4])
print(t.matmul(t.t()))
+print(t.to_dtype(candle.u8))
print(t.to_dtype("u8"))
t = candle.randn((5, 3))
print(t)
+print(t.dtype)