summaryrefslogtreecommitdiff
path: root/candle-pyo3/src
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-02 07:07:22 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-02 07:07:22 +0100
commitc62cb73a7f7487ee76ba0d93c1cde7706b1025b3 (patch)
tree282e636bba4573864bb3227fa811a3e5955cf4d1 /candle-pyo3/src
parentfa58c7643ded869a345b8df538dc38d21684ac88 (diff)
downloadcandle-c62cb73a7f7487ee76ba0d93c1cde7706b1025b3.tar.gz
candle-c62cb73a7f7487ee76ba0d93c1cde7706b1025b3.tar.bz2
candle-c62cb73a7f7487ee76ba0d93c1cde7706b1025b3.zip
Support higher order shapes for conversions.
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r--candle-pyo3/src/lib.rs32
1 files changed, 31 insertions, 1 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 1d3e4efd..4328ac01 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -72,7 +72,37 @@ impl PyTensor {
impl<'a> MapDType for M<'a> {
type Output = PyObject;
fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
- Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0))
+ match t.rank() {
+ 0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)),
+ 1 => {
+ let v = t.to_vec1::<T>().map_err(wrap_err)?;
+ let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>();
+ Ok(v.to_object(self.0))
+ }
+ 2 => {
+ let v = t.to_vec2::<T>().map_err(wrap_err)?;
+ let v = v
+ .iter()
+ .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
+ .collect::<Vec<Vec<_>>>();
+ Ok(v.to_object(self.0))
+ }
+ 3 => {
+ let v = t.to_vec3::<T>().map_err(wrap_err)?;
+ let v = v
+ .iter()
+ .map(|v| {
+ v.iter()
+ .map(|v| v.iter().map(|v| v.to_py(self.0)).collect())
+ .collect()
+ })
+ .collect::<Vec<Vec<Vec<_>>>>();
+ Ok(v.to_object(self.0))
+ }
+ n => Err(PyTypeError::new_err(format!(
+ "TODO: conversion to PyObject is not handled for rank {n}"
+ )))?,
+ }
}
}
// TODO: Handle arbitrary shapes.