summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/src/lib.rs25
1 files changed, 20 insertions, 5 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 5e6f48ea..79f86479 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -205,14 +205,29 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<i64>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
- Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
- Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<f32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
+ let len = vs.len();
+ Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
+ let len = vs.len();
+ Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
- Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
+ let len = vs.len();
+ Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else {
let ty = vs.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(