diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-18 21:54:15 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-18 21:54:15 +0100 |
commit | 143c481c20abc3420e848eab075d1547a96cc447 (patch) | |
tree | f832da1cd35460151e7c0b6313f49c00e8e26053 /candle-pyo3/src/lib.rs | |
parent | f115895b9e981698daa04d0be33555c03f7892ed (diff) | |
download | candle-143c481c20abc3420e848eab075d1547a96cc447.tar.gz candle-143c481c20abc3420e848eab075d1547a96cc447.tar.bz2 candle-143c481c20abc3420e848eab075d1547a96cc447.zip |
Expose candle gather op in pyo3. (#1870)
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 7b9a7413..e0d3bf30 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -448,6 +448,12 @@ impl PyTensor { Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?)) } + /// Gathers values along an axis specified by dim. + fn gather(&self, index: &Self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?)) + } + #[pyo3(text_signature = "(self, rhs:Tensor)")] /// Performs a matrix multiplication between the two tensors. /// &RETURNS&: Tensor |