diff options
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 56 |
1 files changed, 51 insertions, 5 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 43d99c25..5e6f48ea 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -145,6 +145,22 @@ pydtype!(bf16, f32::from); pydtype!(f32, |v| v); pydtype!(f64, |v| v); +fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result<usize> { + let dim = t.dim(dim)?; + if 0 <= index { + let index = index as usize; + if dim <= index { + ::candle::bail!("index {index} is too large for tensor dimension {dim}") + } + Ok(index) + } else { + if (dim as i64) < -index { + ::candle::bail!("index {index} is too low for tensor dimension {dim}") + } + Ok((dim as i64 + index) as usize) + } +} + fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> { let rank = t.rank(); if 0 <= dim { @@ -409,7 +425,8 @@ impl PyTensor { Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) } - fn squeeze(&self, dim: usize) -> PyResult<Self> { + fn squeeze(&self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?)) } @@ -417,7 +434,8 @@ impl PyTensor { Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?)) } - fn get(&self, index: usize) -> PyResult<Self> { + fn get(&self, index: i64) -> PyResult<Self> { + let index = actual_index(self, 0, index).map_err(wrap_err)?; Ok(PyTensor(self.0.get(index).map_err(wrap_err)?)) } @@ -425,11 +443,32 @@ impl PyTensor { Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?)) } - fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult<Self> { + fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; + let start = actual_index(self, dim, start).map_err(wrap_err)?; Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?)) } + fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?)) + } + + fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?)) + } + + fn max_keepdim(&self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?)) + } + + fn min_keepdim(&self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?)) + } + fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> { let dims = if let Ok(dim) = dims.extract::<usize>(py) { vec![dim] @@ -661,7 +700,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> { } #[pyfunction] -fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { +fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; let tensors = ggml @@ -681,7 +720,14 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { ("ftype", ggml.hparams.ftype), ]; let hparams = hparams.into_py_dict(py).to_object(py); - Ok((tensors, hparams)) + let vocab = ggml + .vocab + .token_score_pairs + .iter() + .map(|(bytes, _)| String::from_utf8_lossy(bytes.as_slice()).to_string()) + .collect::<Vec<String>>() + .to_object(py); + Ok((tensors, hparams, vocab)) } #[pyfunction] |