author     Nicolas Patry <patry.nicolas@protonmail.com>   2024-01-17 10:27:58 +0100
committer  GitHub <noreply@github.com>                    2024-01-17 10:27:58 +0100
commit     403680f17ddc086295fbaee316cbed22d97a519b (patch)
tree       80dcffe6e929640e7f0ebfff3ba90410fd58992e /candle-pyo3/src
parent     5270224f407502b82fe90bc2622894ce3871b002 (diff)
Quantized GGUF style (#1523)
* Metal quantized modifications proposal.
- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.
Fix Python.
Fixing examples.
Fix: fmt + clippy + stub.
Moving everything around.
Only missing the actual implems.
Fixing everything + adding dequantized kernels.
More work.
Fixing matmul.
Fmt + Clippy
Some clippy fixes.
Working state.
Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catches it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented, it seems.
Q8K Metal -> Never implemented in Metal.
Fixing Q2K bug (present in ggml).
* Cleanup.
* Fix the rebase.
* Removing the fences speeds everything up and *is* correct this time...
* Cleanup the fence.
* After rebase.
* Bad code removal.
* Rebase after phi2 merge + fix replit default to CPU.
* Making the CI happy.
* More happy tests.
---------
Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
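In short, `QTensor::quantize` now takes the target dtype as a runtime `GgmlDType` value instead of a compile-time block type, and the loaders take a device. A minimal crate-side sketch of that API, assuming the `candle` crate as used in the diff below; the tensor shape and `Device::Cpu` here are illustrative, not taken from this commit:

// Sketch of the new runtime-dtype quantization API; shapes and device are illustrative.
use candle::quantized::{GgmlDType, QTensor};
use candle::{DType, Device, Result, Tensor};

fn quantize_sketch() -> Result<QTensor> {
    let device = Device::Cpu;
    // k-quants pack weights in blocks of 256 values, so the innermost
    // dimension should be a multiple of 256.
    let weights = Tensor::zeros((4, 256), DType::F32, &device)?;
    // Old form: QTensor::quantize::<k_quants::BlockQ4K>(&weights)
    // New form: the dtype is an ordinary value, so the Python binding can
    // pick it from a string at runtime.
    QTensor::quantize(&weights, GgmlDType::Q4K)
}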
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r--  candle-pyo3/src/lib.rs  51
1 file changed, 31 insertions(+), 20 deletions(-)
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 90826b98..ca406876 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1074,20 +1074,20 @@ impl PyTensor {
     fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
         use ::candle::quantized;
         let res = match quantized_dtype.to_lowercase().as_str() {
-            "q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
-            "q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
-            "q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
-            "q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
-            "q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
-            "q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
-            "q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
-            "q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
-            "q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
-            "q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
-            "q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
-            "q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
-            "f16" => quantized::QTensor::quantize::<f16>(self),
-            "f32" => quantized::QTensor::quantize::<f32>(self),
+            "q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
+            "q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),
+            "q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0),
+            "q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1),
+            "q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K),
+            "q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0),
+            "q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1),
+            "q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K),
+            "q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K),
+            "q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0),
+            "q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1),
+            "q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K),
+            "f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16),
+            "f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32),
             dt => {
                 return Err(PyErr::new::<PyValueError, _>(format!(
                     "unknown quantized-dtype {dt}"
@@ -1278,13 +1278,19 @@ fn save_safetensors(
 }
 
 #[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
 /// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
 /// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
 /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
-fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
+fn load_ggml(
+    path: &str,
+    device: Option<PyDevice>,
+    py: Python<'_>,
+) -> PyResult<(PyObject, PyObject, PyObject)> {
     let mut file = std::fs::File::open(path)?;
-    let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
+    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+    let ggml =
+        ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
     let tensors = ggml
         .tensors
         .into_iter()
@@ -1313,11 +1319,16 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObje
 }
 
 #[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
 /// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
 /// and the second maps metadata keys to metadata values.
 /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
-fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+fn load_gguf(
+    path: &str,
+    device: Option<PyDevice>,
+    py: Python<'_>,
+) -> PyResult<(PyObject, PyObject)> {
+    let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
     use ::candle::quantized::gguf_file;
     fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
         let v: PyObject = match v {
@@ -1349,7 +1360,7 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
         .tensor_infos
         .keys()
         .map(|key| {
-            let qtensor = gguf.tensor(&mut file, key)?;
+            let qtensor = gguf.tensor(&mut file, key, &device)?;
             Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
         })
         .collect::<::candle::Result<Vec<_>>>()
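The loading side threads the device the same way. A sketch of the crate-level calls mirrored by the bindings above, assuming the `gguf_file` API exactly as it appears in the hunks; the file path and tensor name are placeholders:

// Sketch: reading one GGUF tensor onto a chosen device, mirroring the
// `gguf.tensor(&mut file, key, &device)` call in the diff above.
// "model.gguf" and "token_embd.weight" are placeholder names.
use candle::quantized::{gguf_file, QTensor};
use candle::{Device, Result};

fn load_one(device: &Device) -> Result<QTensor> {
    let mut file = std::fs::File::open("model.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    // Passing Device::Cpu reproduces the old behaviour; the bindings get the
    // same default via `device.unwrap_or(PyDevice::Cpu)`.
    content.tensor(&mut file, "token_embd.weight", device)
}

Defaulting to CPU when no device is given keeps both pyfunctions backward compatible for existing callers.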