summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-02 20:59:21 +0200
committerGitHub <noreply@github.com>2023-09-02 19:59:21 +0100
commit84d003ff530eb14597f33c8c763eeb573370e22e (patch)
tree5bad6f9e6a974d8413feb5e2887a59374fb2fc2c /candle-pyo3
parent21109e19834ad852e54daef7c7729b535e2241ba (diff)
downloadcandle-84d003ff530eb14597f33c8c763eeb573370e22e.tar.gz
candle-84d003ff530eb14597f33c8c763eeb573370e22e.tar.bz2
candle-84d003ff530eb14597f33c8c763eeb573370e22e.zip
Handle arbitrary shapes in Tensor::new. (#718)
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/src/lib.rs25
1 files changed, 20 insertions, 5 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 5e6f48ea..79f86479 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -205,14 +205,29 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<i64>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
- Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
- Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<f32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
+ let len = vs.len();
+ Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
+ let len = vs.len();
+ Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
- Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
+ let len = vs.len();
+ Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
+ Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else {
let ty = vs.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(