diff options
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/src/lib.rs | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 5e6f48ea..79f86479 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -205,14 +205,29 @@ impl PyTensor { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::<i64>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::<f32>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) { + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) { + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? } else { let ty = vs.as_ref(py).get_type(); Err(PyTypeError::new_err(format!( |