diff options
author | laurent <laurent.mazare@gmail.com> | 2023-07-02 07:07:22 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-07-02 07:07:22 +0100 |
commit | c62cb73a7f7487ee76ba0d93c1cde7706b1025b3 (patch) | |
tree | 282e636bba4573864bb3227fa811a3e5955cf4d1 /candle-pyo3/src | |
parent | fa58c7643ded869a345b8df538dc38d21684ac88 (diff) | |
download | candle-c62cb73a7f7487ee76ba0d93c1cde7706b1025b3.tar.gz candle-c62cb73a7f7487ee76ba0d93c1cde7706b1025b3.tar.bz2 candle-c62cb73a7f7487ee76ba0d93c1cde7706b1025b3.zip |
Support higher order shapes for conversions.
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r-- | candle-pyo3/src/lib.rs | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 1d3e4efd..4328ac01 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -72,7 +72,37 @@ impl PyTensor { impl<'a> MapDType for M<'a> { type Output = PyObject; fn f<T: PyDType>(&self, t: &Tensor) -> PyResult<Self::Output> { - Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)) + match t.rank() { + 0 => Ok(t.to_scalar::<T>().map_err(wrap_err)?.to_py(self.0)), + 1 => { + let v = t.to_vec1::<T>().map_err(wrap_err)?; + let v = v.iter().map(|v| v.to_py(self.0)).collect::<Vec<_>>(); + Ok(v.to_object(self.0)) + } + 2 => { + let v = t.to_vec2::<T>().map_err(wrap_err)?; + let v = v + .iter() + .map(|v| v.iter().map(|v| v.to_py(self.0)).collect()) + .collect::<Vec<Vec<_>>>(); + Ok(v.to_object(self.0)) + } + 3 => { + let v = t.to_vec3::<T>().map_err(wrap_err)?; + let v = v + .iter() + .map(|v| { + v.iter() + .map(|v| v.iter().map(|v| v.to_py(self.0)).collect()) + .collect() + }) + .collect::<Vec<Vec<Vec<_>>>>(); + Ok(v.to_object(self.0)) + } + n => Err(PyTypeError::new_err(format!( + "TODO: conversion to PyObject is not handled for rank {n}" + )))?, + } } } // TODO: Handle arbitrary shapes. |