summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r--candle-pyo3/src/lib.rs56
1 files changed, 51 insertions, 5 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 43d99c25..5e6f48ea 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -145,6 +145,22 @@ pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);
+fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result<usize> {
+ let dim = t.dim(dim)?;
+ if 0 <= index {
+ let index = index as usize;
+ if dim <= index {
+ ::candle::bail!("index {index} is too large for tensor dimension {dim}")
+ }
+ Ok(index)
+ } else {
+ if (dim as i64) < -index {
+ ::candle::bail!("index {index} is too low for tensor dimension {dim}")
+ }
+ Ok((dim as i64 + index) as usize)
+ }
+}
+
fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> {
let rank = t.rank();
if 0 <= dim {
@@ -409,7 +425,8 @@ impl PyTensor {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
}
- fn squeeze(&self, dim: usize) -> PyResult<Self> {
+ fn squeeze(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
}
@@ -417,7 +434,8 @@ impl PyTensor {
Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
}
- fn get(&self, index: usize) -> PyResult<Self> {
+ fn get(&self, index: i64) -> PyResult<Self> {
+ let index = actual_index(self, 0, index).map_err(wrap_err)?;
Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
}
@@ -425,11 +443,32 @@ impl PyTensor {
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
}
- fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult<Self> {
+ fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ let start = actual_index(self, dim, start).map_err(wrap_err)?;
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
+ fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
+ }
+
+ fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
+ }
+
+ fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
+ }
+
+ fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
+ let dim = actual_dim(self, dim).map_err(wrap_err)?;
+ Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
+ }
+
fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
let dims = if let Ok(dim) = dims.extract::<usize>(py) {
vec![dim]
@@ -661,7 +700,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}
#[pyfunction]
-fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
let tensors = ggml
@@ -681,7 +720,14 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
("ftype", ggml.hparams.ftype),
];
let hparams = hparams.into_py_dict(py).to_object(py);
- Ok((tensors, hparams))
+ let vocab = ggml
+ .vocab
+ .token_score_pairs
+ .iter()
+ .map(|(bytes, _)| String::from_utf8_lossy(bytes.as_slice()).to_string())
+ .collect::<Vec<String>>()
+ .to_object(py);
+ Ok((tensors, hparams, vocab))
}
#[pyfunction]