summary refs log tree commitdiff
path: root/candle-pyo3/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r--  candle-pyo3/src/lib.rs  51
1 file changed, 31 insertions(+), 20 deletions(-)
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 90826b98..ca406876 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1074,20 +1074,20 @@ impl PyTensor {
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
let res = match quantized_dtype.to_lowercase().as_str() {
- "q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
- "q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
- "q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
- "q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
- "q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
- "q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
- "q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
- "q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
- "q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
- "q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
- "q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
- "q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
- "f16" => quantized::QTensor::quantize::<f16>(self),
- "f32" => quantized::QTensor::quantize::<f32>(self),
+ "q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
+ "q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),
+ "q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0),
+ "q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1),
+ "q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K),
+ "q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0),
+ "q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1),
+ "q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K),
+ "q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K),
+ "q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0),
+ "q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1),
+ "q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K),
+ "f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16),
+ "f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32),
dt => {
return Err(PyErr::new::<PyValueError, _>(format!(
"unknown quantized-dtype {dt}"
@@ -1278,13 +1278,19 @@ fn save_safetensors(
}
#[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
-fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
+fn load_ggml(
+ path: &str,
+ device: Option<PyDevice>,
+ py: Python<'_>,
+) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
- let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let ggml =
+ ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
let tensors = ggml
.tensors
.into_iter()
@@ -1313,11 +1319,16 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
}
#[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
-fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+fn load_gguf(
+ path: &str,
+ device: Option<PyDevice>,
+ py: Python<'_>,
+) -> PyResult<(PyObject, PyObject)> {
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
use ::candle::quantized::gguf_file;
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
let v: PyObject = match v {
@@ -1349,7 +1360,7 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
.tensor_infos
.keys()
.map(|key| {
- let qtensor = gguf.tensor(&mut file, key)?;
+ let qtensor = gguf.tensor(&mut file, key, &device)?;
Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
})
.collect::<::candle::Result<Vec<_>>>()