diff options
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 111 |
1 files changed, 105 insertions, 6 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 2673d843..43d99c25 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -145,6 +145,22 @@ pydtype!(bf16, f32::from); pydtype!(f32, |v| v); pydtype!(f64, |v| v); +fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> { + let rank = t.rank(); + if 0 <= dim { + let dim = dim as usize; + if rank <= dim { + ::candle::bail!("dimension index {dim} is too large for tensor rank {rank}") + } + Ok(dim) + } else { + if (rank as i64) < -dim { + ::candle::bail!("dimension index {dim} is too low for tensor rank {rank}") + } + Ok((rank as i64 + dim) as usize) + } +} + // TODO: Something similar to this should probably be a part of candle core. trait MapDType { type Output; @@ -182,7 +198,10 @@ impl PyTensor { } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) { Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? } else { - Err(PyTypeError::new_err("incorrect type for tensor"))? + let ty = vs.as_ref(py).get_type(); + Err(PyTypeError::new_err(format!( + "incorrect type {ty} for tensor" + )))? }; Ok(Self(tensor)) } @@ -295,10 +314,31 @@ impl PyTensor { Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?)) } + fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?)) + } + fn matmul(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?)) } + fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> { + Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?)) + } + + fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> { + Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?)) + } + + fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> { + Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?)) + } + + fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> { + Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?)) + } + fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> { Ok(PyTensor( self.0.where_cond(on_true, on_false).map_err(wrap_err)?, @@ -346,6 +386,17 @@ impl PyTensor { Ok(Self(tensor)) } + fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> { + let tensor = if let Ok(rhs) = rhs.extract::<Self>() { + (&self.0 / &rhs.0).map_err(wrap_err)? + } else if let Ok(rhs) = rhs.extract::<f64>() { + (&self.0 / rhs).map_err(wrap_err)? + } else { + Err(PyTypeError::new_err("unsupported rhs for div"))? + }; + Ok(Self(tensor)) + } + fn reshape(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) } @@ -374,7 +425,8 @@ impl PyTensor { Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?)) } - fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> { + fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?)) } @@ -400,6 +452,16 @@ impl PyTensor { Ok(PyTensor(mean)) } + fn flatten_from(&self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?)) + } + + fn flatten_to(&self, dim: i64) -> PyResult<Self> { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?)) + } + fn flatten_all(&self) -> PyResult<Self> { Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?)) } @@ -467,7 +529,11 @@ impl PyTensor { /// Concatenate the tensors across one axis. #[pyfunction] -fn cat(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> { +fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> { + if tensors.is_empty() { + return Err(PyErr::new::<PyValueError, _>("empty input to cat")); + } + let dim = actual_dim(&tensors[0], dim).map_err(wrap_err)?; let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>(); let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?; Ok(PyTensor(tensor)) @@ -595,16 +661,27 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> { } #[pyfunction] -fn load_ggml(path: &str, py: Python<'_>) -> PyResult<PyObject> { +fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; - let res = ggml + let tensors = ggml .tensors .into_iter() .map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))) .collect::<::candle::Result<Vec<_>>>() .map_err(wrap_err)?; - Ok(res.into_py_dict(py).to_object(py)) + let tensors = tensors.into_py_dict(py).to_object(py); + let hparams = [ + ("n_vocab", ggml.hparams.n_vocab), + ("n_embd", ggml.hparams.n_embd), + ("n_mult", ggml.hparams.n_mult), + ("n_head", ggml.hparams.n_head), + ("n_layer", ggml.hparams.n_layer), + ("n_rot", ggml.hparams.n_rot), + ("ftype", ggml.hparams.ftype), + ]; + let hparams = hparams.into_py_dict(py).to_object(py); + Ok((tensors, hparams)) } #[pyfunction] @@ -651,11 +728,33 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> { Ok(()) } +#[pyfunction] +fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> { + let dim = actual_dim(&t, dim).map_err(wrap_err)?; + let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?; + Ok(PyTensor(sm)) +} + +#[pyfunction] +fn silu(t: PyTensor) -> PyResult<PyTensor> { + let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?; + Ok(PyTensor(s)) +} + +fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(silu, m)?)?; + m.add_function(wrap_pyfunction!(softmax, m)?)?; + Ok(()) +} + #[pymodule] fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { let utils = PyModule::new(py, "utils")?; candle_utils(py, utils)?; m.add_submodule(utils)?; + let nn = PyModule::new(py, "nn")?; + candle_nn_m(py, nn)?; + m.add_submodule(nn)?; m.add_class::<PyTensor>()?; m.add_class::<PyQTensor>()?; m.add_class::<PyDType>()?; |