diff options
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 134 |
1 file changed, 130 insertions, 4 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 64b6dd2c..4d4b5200 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -3,6 +3,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyDict, PyTuple}; use pyo3::ToPyObject; +use std::os::raw::c_long; use std::sync::Arc; use half::{bf16, f16}; @@ -196,6 +197,12 @@ trait MapDType { } } +enum Indexer { + Index(usize), + Slice(usize, usize), + Elipsis, +} + #[pymethods] impl PyTensor { #[new] @@ -436,6 +443,95 @@ impl PyTensor { )) } + #[getter] + /// Index a tensor. + /// &RETURNS&: Tensor + fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> { + let mut indexers: Vec<Indexer> = vec![]; + let dims = self.0.shape().dims(); + + let to_absolute_index = |index: isize, current_dim: usize| { + // Convert a relative index to an absolute index e.g. tensor[-1] -> tensor[0] + let actual_index = if index < 0 { + dims[current_dim] as isize + index + } else { + index + }; + + // Check that the index is in range + if actual_index < 0 || actual_index >= dims[current_dim] as isize { + return Err(PyTypeError::new_err(format!( + "index out of range for dimension '{i}' with indexer '{value}'", + i = current_dim, + value = index + ))); + } + Ok(actual_index as usize) + }; + if let Ok(index) = idx.extract(py) { + // Handle a single index e.g. tensor[0] or tensor[-1] + indexers.push(Indexer::Index(to_absolute_index(index, 0)?)); + } else if let Ok(slice) = idx.downcast::<pyo3::types::PySlice>(py) { + // Handle a single slice e.g. tensor[0:1] or tensor[0:-1] + let index = slice.indices(dims[0] as c_long)?; + indexers.push(Indexer::Slice(index.start as usize, index.stop as usize)); + } else if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) { + // Handle multiple indices e.g. 
tensor[0,0] or tensor[0:1,0:1] + + if tuple.len() > dims.len() { + return Err(PyTypeError::new_err("provided too many indices")); + } + + for (i, item) in tuple.iter().enumerate() { + if item.is_ellipsis() { + // Handle '...' e.g. tensor[..., 0] + + if i > 0 { + return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation")); + } + indexers.push(Indexer::Elipsis); + } else if let Ok(slice) = item.downcast::<pyo3::types::PySlice>() { + // Handle slice + let index = slice.indices(dims[i] as c_long)?; + indexers.push(Indexer::Slice(index.start as usize, index.stop as usize)); + } else if let Ok(index) = item.extract::<isize>() { + indexers.push(Indexer::Index(to_absolute_index(index, i)?)); + } else { + return Err(PyTypeError::new_err("unsupported index")); + } + } + } else { + return Err(PyTypeError::new_err("unsupported index")); + } + + let mut x = self.0.clone(); + let mut current_dim = 0; + // Apply the indexers + for indexer in indexers.iter() { + x = match indexer { + Indexer::Index(n) => x + .narrow(current_dim, *n, 1) + .map_err(wrap_err)? + .squeeze(current_dim) + .map_err(wrap_err)?, + Indexer::Slice(start, stop) => { + let out = x + .narrow(current_dim, *start, stop.saturating_sub(*start)) + .map_err(wrap_err)?; + current_dim += 1; + out + } + Indexer::Elipsis => { + // Elipsis is a special case, it means that all remaining dimensions should be selected => advance the current_dim to the last dimension we have indexers for + current_dim += dims.len() - (indexers.len() - 1); + x + } + } + } + + Ok(Self(x)) + } + /// Add two tensors. 
/// &RETURNS&: Tensor fn __add__(&self, rhs: &PyAny) -> PyResult<Self> { @@ -697,7 +793,7 @@ impl PyTensor { /// &RETURNS&: QTensor fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> { use ::candle::quantized; - let res = match quantized_dtype { + let res = match quantized_dtype.to_lowercase().as_str() { "q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self), "q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self), "q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self), @@ -1137,9 +1233,39 @@ fn silu(tensor: PyTensor) -> PyResult<PyTensor> { Ok(PyTensor(s)) } -fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { +#[pyfunction] +#[pyo3(text_signature = "(tensor:Tensor)")] +/// Applies the Gaussian Error Linear Unit (GELU) function to a given tensor. +/// &RETURNS&: Tensor +fn gelu(tensor: PyTensor) -> PyResult<PyTensor> { + let s = tensor.0.gelu_erf().map_err(wrap_err)?; + Ok(PyTensor(s)) +} + +#[pyfunction] +#[pyo3(text_signature = "(tensor:Tensor)")] +/// Applies the Rectified Linear Unit (ReLU) function to a given tensor. +/// &RETURNS&: Tensor +fn relu(tensor: PyTensor) -> PyResult<PyTensor> { + let s = tensor.0.relu().map_err(wrap_err)?; + Ok(PyTensor(s)) +} + +#[pyfunction] +#[pyo3(text_signature = "(tensor:Tensor)")] +/// Applies the tanh function to a given tensor. 
+/// &RETURNS&: Tensor +fn tanh(tensor: PyTensor) -> PyResult<PyTensor> { + let s = tensor.0.tanh().map_err(wrap_err)?; + Ok(PyTensor(s)) +} + +fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(silu, m)?)?; m.add_function(wrap_pyfunction!(softmax, m)?)?; + m.add_function(wrap_pyfunction!(gelu, m)?)?; + m.add_function(wrap_pyfunction!(relu, m)?)?; + m.add_function(wrap_pyfunction!(tanh, m)?)?; Ok(()) } @@ -1148,8 +1274,8 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { let utils = PyModule::new(py, "utils")?; candle_utils(py, utils)?; m.add_submodule(utils)?; - let nn = PyModule::new(py, "nn")?; - candle_nn_m(py, nn)?; + let nn = PyModule::new(py, "functional")?; + candle_functional_m(py, nn)?; m.add_submodule(nn)?; m.add_class::<PyTensor>()?; m.add_class::<PyQTensor>()?; |