author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-16 12:41:07 +0100
committer  GitHub <noreply@github.com>  2023-08-16 12:41:07 +0100
commit     3071134788334c972d9e356f53887d2b2ff026b7 (patch)
tree       adbb58e3babee5d62fa6150bde4f9bb03770607c /candle-core/src/quantized/ggml_file.rs
parent     fec87e86f50da78656a0fb28fc254390435fb3fd (diff)
Get the ggml based llama to generate some text. (#464)
* Add more stats to the ggml example.
* Build a quantized model from the file content.
* Move the tensor retrieval in the main crate.
* Start adding the forward pass.
* Add more to the forward pass of the quantized llama.
* Apply the attention layers.
* Add the sampling loop.
* Get the sampling loop to work.
* Minor tweak.
* Add a quantize/dequantize test.
* Bugfix.
* Add a comment + swap the order.
* Bugfixes.
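As a rough illustration of the file-loading part of this pipeline, here is a minimal sketch of driving the new API. It assumes the `Content::read(&mut reader)` entry point accepts any `Seek + Read` source (as the hunk below suggests), that `ggml_file` is publicly reachable under `candle_core::quantized`, and llama-style tensor naming; the model path is hypothetical.

use candle_core::quantized::ggml_file::Content;
use std::fs::File;
use std::io::BufReader;

fn main() -> candle_core::Result<()> {
    // Hypothetical path to a ggml llama checkpoint.
    let file = File::open("llama-7b.ggmlv3.q4_0.bin")?;
    let mut reader = BufReader::new(file);
    // Parse the magic, hparams, vocab and all quantized tensors.
    let mut content = Content::read(&mut reader)?;
    println!("n_vocab: {}, tensors: {}", content.hparams.n_vocab, content.tensors.len());
    // Tensors are now keyed by name; `remove` hands ownership to the caller.
    let _tok_embeddings = content.remove("tok_embeddings.weight")?;
    Ok(())
}

Since `tensors` is now a `HashMap`, name lookups replace the linear scan the old `Vec<(String, QTensor)>` would have required.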
Diffstat (limited to 'candle-core/src/quantized/ggml_file.rs')
-rw-r--r--  candle-core/src/quantized/ggml_file.rs  18
1 file changed, 14 insertions(+), 4 deletions(-)
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index ee23cdde..7afb8670 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -3,6 +3,7 @@
 use super::{k_quants, GgmlDType};
 use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
+use std::collections::HashMap;
 
 // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -163,6 +164,9 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
     let mut dims = vec![0u32; n_dims as usize];
     reader.read_u32_into::<LittleEndian>(&mut dims)?;
+    // The dimensions are stored in reverse order, see for example:
+    // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
+    dims.reverse();
     let mut name = vec![0u8; name_len as usize];
     reader.read_exact(&mut name)?;
     let name = String::from_utf8_lossy(&name).into_owned();
@@ -174,7 +178,6 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
     let tensor_elems = dims.iter().product::<usize>();
     let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
-    println!("{name} {ggml_dtype:?} {dims:?}");
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
@@ -188,7 +191,7 @@ pub struct Content {
     pub magic: VersionedMagic,
     pub hparams: HParams,
     pub vocab: Vocab,
-    pub tensors: Vec<(String, super::QTensor)>,
+    pub tensors: HashMap<String, super::QTensor>,
 }
 
 impl Content {
@@ -199,11 +202,11 @@ impl Content {
         let magic = VersionedMagic::read(reader)?;
         let hparams = HParams::read(reader)?;
         let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
-        let mut tensors = vec![];
+        let mut tensors = HashMap::new();
 
         while reader.stream_position()? != last_position {
             let (name, tensor) = read_one_tensor(reader, magic)?;
-            tensors.push((name, tensor))
+            tensors.insert(name, tensor);
         }
         Ok(Self {
             magic,
@@ -212,4 +215,11 @@ impl Content {
             tensors,
         })
     }
+
+    pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
+        match self.tensors.remove(name) {
+            None => crate::bail!("cannot find tensor with name '{name}'"),
+            Some(tensor) => Ok(tensor),
+        }
+    }
 }
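Two details in this patch are easy to miss: the on-disk dimension order is the reverse of what candle expects, and the size of a quantized tensor is computed per block rather than per element. Below is a standalone sketch with made-up shapes; the q4_0 constants (32 elements and 18 bytes per block) match ggml, everything else is illustrative.

fn main() {
    // As stored on disk by ggml: innermost dimension first.
    let mut dims = vec![4096u32, 32000];
    dims.reverse(); // candle order: [32000, 4096], i.e. n_vocab x n_embd

    // q4_0 packs 32 elements into an 18-byte block
    // (a 2-byte f16 scale followed by 16 bytes of 4-bit quants).
    let (type_size, blck_size) = (18usize, 32usize);
    let tensor_elems: usize = dims.iter().map(|&u| u as usize).product();
    let size_in_bytes = tensor_elems * type_size / blck_size;
    assert_eq!(size_in_bytes, 73_728_000); // vs 524_288_000 bytes as f32
    println!("{tensor_elems} elements -> {size_in_bytes} bytes on disk");
}

Having `remove` return the tensor by value is a good fit for model loading: each weight is moved out of the map exactly once, so the quantized data never needs to be cloned.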