diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-02 20:59:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-02 19:59:21 +0100 |
commit | 84d003ff530eb14597f33c8c763eeb573370e22e (patch) | |
tree | 5bad6f9e6a974d8413feb5e2887a59374fb2fc2c /candle-pyo3 | |
parent | 21109e19834ad852e54daef7c7729b535e2241ba (diff) | |
download | candle-84d003ff530eb14597f33c8c763eeb573370e22e.tar.gz candle-84d003ff530eb14597f33c8c763eeb573370e22e.tar.bz2 candle-84d003ff530eb14597f33c8c763eeb573370e22e.zip |
Handle arbitrary shapes in Tensor::new. (#718)
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/src/lib.rs | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 5e6f48ea..79f86479 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -205,14 +205,29 @@ impl PyTensor { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::<i64>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::<f32>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) { + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) { + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? } else { let ty = vs.as_ref(py).get_type(); Err(PyTypeError::new_err(format!( |