summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi6
-rw-r--r--candle-pyo3/src/lib.rs6
2 files changed, 12 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index aef0707d..b0f05de5 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -324,6 +324,12 @@ class Tensor:
"""
pass
+ def gather(self, index, dim):
+ """
+ Gathers values along an axis specified by dim.
+ """
+ pass
+
def get(self, index: int) -> Tensor:
"""
Gets the value at the specified index.
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 7b9a7413..e0d3bf30 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -448,6 +448,12 @@ impl PyTensor {
Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
}
+ /// Gathers values along an axis specified by dim.
+ fn gather(&self, index: &Self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?))
+ }
+
#[pyo3(text_signature = "(self, rhs:Tensor)")]
/// Performs a matrix multiplication between the two tensors.
/// &RETURNS&: Tensor