summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/lib.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-18 21:54:15 +0100
committerGitHub <noreply@github.com>2024-03-18 21:54:15 +0100
commit143c481c20abc3420e848eab075d1547a96cc447 (patch)
treef832da1cd35460151e7c0b6313f49c00e8e26053 /candle-pyo3/src/lib.rs
parentf115895b9e981698daa04d0be33555c03f7892ed (diff)
downloadcandle-143c481c20abc3420e848eab075d1547a96cc447.tar.gz
candle-143c481c20abc3420e848eab075d1547a96cc447.tar.bz2
candle-143c481c20abc3420e848eab075d1547a96cc447.zip
Expose candle gather op in pyo3. (#1870)
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r--candle-pyo3/src/lib.rs6
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 7b9a7413..e0d3bf30 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -448,6 +448,12 @@ impl PyTensor {
Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
}
+ /// Gathers values along an axis specified by dim.
+ fn gather(&self, index: &Self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?))
+ }
+
#[pyo3(text_signature = "(self, rhs:Tensor)")]
/// Performs a matrix multiplication between the two tensors.
/// &RETURNS&: Tensor