summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r--candle-pyo3/src/lib.rs134
1 files changed, 130 insertions, 4 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 64b6dd2c..4d4b5200 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -3,6 +3,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
+use std::os::raw::c_long;
use std::sync::Arc;
use half::{bf16, f16};
@@ -196,6 +197,12 @@ trait MapDType {
}
}
+enum Indexer { // One parsed component of a Python subscript, produced and consumed by `__getitem__`.
+ Index(usize), // A single position along one dimension, already converted to a non-negative index.
+ Slice(usize, usize), // A `start, stop` pair taken from a Python slice; applied as a narrow of length `stop - start`.
+ Elipsis, // Pushed when the subscript item is Python's `...`.
+}
+
#[pymethods]
impl PyTensor {
#[new]
@@ -436,6 +443,95 @@ impl PyTensor {
))
}
+    #[getter]
+    /// Index a tensor.
+    /// &RETURNS&: Tensor
+    fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
+        // Accepted subscripts:
+        //   * a single integer        e.g. tensor[0] or tensor[-1]
+        //   * a single slice          e.g. tensor[0:1] or tensor[0:-1]
+        //   * a tuple of the above, optionally led by an ellipsis, e.g. tensor[..., 0]
+        // Raises TypeError for unsupported/out-of-range indexers and ValueError for
+        // slices whose step is not 1 (strided views are not implemented here).
+        let mut indexers: Vec<Indexer> = vec![];
+        let dims = self.0.shape().dims();
+
+        let too_many_indices = || PyTypeError::new_err("provided too many indices");
+
+        // Resolve a possibly negative index against dimension `dim`,
+        // e.g. for a dimension of size 3: tensor[-1] -> tensor[2].
+        let to_absolute_index = |index: isize, dim: usize| -> PyResult<usize> {
+            let size = dims[dim] as isize;
+            let actual_index = if index < 0 { size + index } else { index };
+
+            // Check that the index is in range
+            if !(0..size).contains(&actual_index) {
+                return Err(PyTypeError::new_err(format!(
+                    "index out of range for dimension '{i}' with indexer '{value}'",
+                    i = dim,
+                    value = index
+                )));
+            }
+            Ok(actual_index as usize)
+        };
+
+        // Resolve a Python slice against dimension `dim`. `PySlice::indices` clamps
+        // start/stop into `0..=len` for a positive step, so the casts cannot wrap.
+        let to_slice_indexer = |slice: &pyo3::types::PySlice, dim: usize| -> PyResult<Indexer> {
+            let index = slice.indices(dims[dim] as c_long)?;
+            if index.step != 1 {
+                // Previously the step was silently dropped, returning the wrong elements.
+                return Err(PyValueError::new_err(format!(
+                    "slice step must be 1 for dimension '{i}', got '{step}'",
+                    i = dim,
+                    step = index.step
+                )));
+            }
+            Ok(Indexer::Slice(index.start as usize, index.stop as usize))
+        };
+
+        if let Ok(index) = idx.extract::<isize>(py) {
+            // Handle a single index e.g. tensor[0] or tensor[-1]
+            if dims.is_empty() {
+                // A rank-0 tensor has no dimension 0 to index into.
+                return Err(too_many_indices());
+            }
+            indexers.push(Indexer::Index(to_absolute_index(index, 0)?));
+        } else if let Ok(slice) = idx.downcast::<pyo3::types::PySlice>(py) {
+            // Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
+            if dims.is_empty() {
+                return Err(too_many_indices());
+            }
+            indexers.push(to_slice_indexer(slice, 0)?);
+        } else if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
+            // Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1]
+
+            // The ellipsis itself consumes no dimension, so only the remaining
+            // (explicit) indexers count against the tensor rank.
+            let leading_ellipsis = tuple
+                .iter()
+                .next()
+                .map_or(false, |first| first.is_ellipsis());
+            let explicit = tuple.len() - usize::from(leading_ellipsis);
+            if explicit > dims.len() {
+                return Err(too_many_indices());
+            }
+            // Number of leading dimensions the ellipsis stands for.
+            let skipped = if leading_ellipsis {
+                dims.len() - explicit
+            } else {
+                0
+            };
+
+            for (i, item) in tuple.iter().enumerate() {
+                if item.is_ellipsis() {
+                    // Handle '...' e.g. tensor[..., 0]
+
+                    if i > 0 {
+                        return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation"));
+                    }
+                    indexers.push(Indexer::Elipsis);
+                    continue;
+                }
+
+                // Dimension of the original tensor this item addresses. After a leading
+                // ellipsis the items address the trailing dimensions, not dimension `i`.
+                let dim = if leading_ellipsis { skipped + i - 1 } else { i };
+
+                if let Ok(slice) = item.downcast::<pyo3::types::PySlice>() {
+                    // Handle slice
+                    indexers.push(to_slice_indexer(slice, dim)?);
+                } else if let Ok(index) = item.extract::<isize>() {
+                    indexers.push(Indexer::Index(to_absolute_index(index, dim)?));
+                } else {
+                    return Err(PyTypeError::new_err("unsupported index"));
+                }
+            }
+        } else {
+            return Err(PyTypeError::new_err("unsupported index"));
+        }
+
+        let mut x = self.0.clone();
+        // `current_dim` is relative to the *current* shape of `x`: an integer index
+        // squeezes its dimension away (so the next dimension slides into place),
+        // while a slice keeps the dimension and therefore advances the cursor.
+        let mut current_dim = 0;
+        // Apply the indexers
+        for indexer in indexers.iter() {
+            x = match indexer {
+                Indexer::Index(n) => x
+                    .narrow(current_dim, *n, 1)
+                    .map_err(wrap_err)?
+                    .squeeze(current_dim)
+                    .map_err(wrap_err)?,
+                Indexer::Slice(start, stop) => {
+                    let out = x
+                        .narrow(current_dim, *start, stop.saturating_sub(*start))
+                        .map_err(wrap_err)?;
+                    current_dim += 1;
+                    out
+                }
+                Indexer::Elipsis => {
+                    // Ellipsis selects every leading dimension that has no explicit indexer:
+                    // jump to the first dimension the remaining indexers apply to.
+                    current_dim += dims.len() - (indexers.len() - 1);
+                    x
+                }
+            }
+        }
+
+        Ok(Self(x))
+    }
+
/// Add two tensors.
/// &RETURNS&: Tensor
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
@@ -697,7 +793,7 @@ impl PyTensor {
/// &RETURNS&: QTensor
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
- let res = match quantized_dtype {
+ let res = match quantized_dtype.to_lowercase().as_str() {
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
@@ -1137,9 +1233,39 @@ fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
Ok(PyTensor(s))
}
-fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+#[pyfunction]
+#[pyo3(text_signature = "(tensor:Tensor)")]
+/// Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
+/// &RETURNS&: Tensor
+fn gelu(tensor: PyTensor) -> PyResult<PyTensor> {
+    // Delegates to the wrapped tensor's `gelu_erf`; any failure goes through `wrap_err`.
+    Ok(PyTensor(tensor.0.gelu_erf().map_err(wrap_err)?))
+}
+
+#[pyfunction]
+#[pyo3(text_signature = "(tensor:Tensor)")]
+/// Applies the Rectified Linear Unit (ReLU) function to a given tensor.
+/// &RETURNS&: Tensor
+fn relu(tensor: PyTensor) -> PyResult<PyTensor> {
+    // Delegates to the wrapped tensor's `relu`; any failure goes through `wrap_err`.
+    Ok(PyTensor(tensor.0.relu().map_err(wrap_err)?))
+}
+
+#[pyfunction]
+#[pyo3(text_signature = "(tensor:Tensor)")]
+/// Applies the tanh function to a given tensor.
+/// &RETURNS&: Tensor
+fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
+    // Delegates to the wrapped tensor's `tanh`; any failure goes through `wrap_err`.
+    Ok(PyTensor(tensor.0.tanh().map_err(wrap_err)?))
+}
+
+fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { // Registers silu, softmax, gelu, relu and tanh on module `m`; `_py` is unused.
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
+ m.add_function(wrap_pyfunction!(gelu, m)?)?; // each `?` returns early if wrapping or registration fails
+ m.add_function(wrap_pyfunction!(relu, m)?)?;
+ m.add_function(wrap_pyfunction!(tanh, m)?)?;
Ok(())
}
@@ -1148,8 +1274,8 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
candle_utils(py, utils)?;
m.add_submodule(utils)?;
- let nn = PyModule::new(py, "nn")?;
- candle_nn_m(py, nn)?;
+ let nn = PyModule::new(py, "functional")?;
+ candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;