From 84d003ff530eb14597f33c8c763eeb573370e22e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 2 Sep 2023 20:59:21 +0200 Subject: Handle arbitrary shapes in Tensor::new. (#718) --- candle-pyo3/src/lib.rs | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) (limited to 'candle-pyo3') diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 5e6f48ea..79f86479 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -205,14 +205,29 @@ impl PyTensor { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>(py) { + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>(py) { + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? } else { let ty = vs.as_ref(py).get_type(); Err(PyTypeError::new_err(format!( -- cgit v1.2.3