diff options
Diffstat (limited to 'candle-core/src/ggml.rs')
-rw-r--r-- | candle-core/src/ggml.rs | 582 |
1 files changed, 582 insertions, 0 deletions
diff --git a/candle-core/src/ggml.rs b/candle-core/src/ggml.rs new file mode 100644 index 00000000..4a5d4fa0 --- /dev/null +++ b/candle-core/src/ggml.rs @@ -0,0 +1,582 @@ +//! Support for the GGML file format. + +use crate::{DType, Device, Result, Tensor}; +use byteorder::{LittleEndian, ReadBytesExt}; +use half::f16; + +// Default to QK_K 256 rather than 64. +pub const QK_K: usize = 256; +pub const K_SCALE_SIZE: usize = 12; + +pub const QK4_0: usize = 32; +pub const QK4_1: usize = 32; +pub const QK5_0: usize = 32; +pub const QK5_1: usize = 32; +pub const QK8_0: usize = 32; +pub const QK8_1: usize = 32; + +#[repr(C)] +struct BlockQ4_0 { + d: f16, + qs: [u8; QK4_0 / 2], +} +const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18); + +#[repr(C)] +struct BlockQ4_1 { + d: f16, + m: f16, + qs: [u8; QK4_1 / 2], +} +const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20); + +#[repr(C)] +struct BlockQ5_0 { + d: f16, + qh: [u8; 4], + qs: [u8; QK5_0 / 2], +} +const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22); + +#[repr(C)] +struct BlockQ5_1 { + d: f16, + m: f16, + qh: [u8; 4], + qs: [u8; QK5_1 / 2], +} +const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24); + +#[repr(C)] +struct BlockQ8_0 { + d: f16, + qs: [u8; QK8_0], +} +const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34); + +#[repr(C)] +struct BlockQ8_1 { + d: f16, + s: f16, + qs: [u8; QK8_1], +} +const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36); + +#[repr(C)] +struct BlockQ2K { + scales: [u8; QK_K / 16], + qs: [u8; QK_K / 4], + d: f16, + dmin: f16, +} +const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>()); + +#[repr(C)] +struct BlockQ3K { + hmask: [u8; QK_K / 8], + qs: [u8; QK_K / 4], + scales: [u8; 12], + d: f16, +} +const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>()); + +// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82 +#[repr(C)] +struct BlockQ4K { + d: f16, + dmin: f16, + scales: [u8; K_SCALE_SIZE], + qs: [u8; QK_K / 2], +} +const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>()); + +#[repr(C)] +struct BlockQ5K { + d: f16, + dmin: f16, + scales: [u8; K_SCALE_SIZE], + qh: [u8; QK_K / 8], + qs: [u8; QK_K / 2], +} +const _: () = + assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>()); + +#[repr(C)] +struct BlockQ6K { + ql: [u8; QK_K / 2], + qh: [u8; QK_K / 4], + scales: [i8; QK_K / 16], + d: f16, +} +const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>()); + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354 +fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}") + } + let mut ys_index = 0; + for x in xs { + let d = x.d.to_f32(); + let min = x.dmin.to_f32(); + let q = &x.qs; + + let mut is = 0; + for n in (0..QK_K).step_by(128) { + // Step by 32 over q. + let q = &q[n / 4..]; + let mut shift = 0; + for _j in 0..4 { + let sc = x.scales[is]; + is += 1; + let dl = d * (sc & 0xF) as f32; + let ml = min * (sc >> 4) as f32; + for q in &q[..16] { + let y = dl * ((q >> shift) & 3) as i8 as f32 - ml; + ys[ys_index] = y; + ys_index += 1; + } + + let sc = x.scales[is]; + is += 1; + let dl = d * (sc & 0xF) as f32; + let ml = min * (sc >> 4) as f32; + for q in &q[16..32] { + let y = dl * ((q >> shift) & 3) as i8 as f32 - ml; + ys[ys_index] = y; + ys_index += 1; + } + + shift += 2; + } + } + } + Ok(()) +} + +fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) { + if j < 4 { + let d = q[j] & 63; + let m = q[j + 4] & 63; + (d, m) + } else { + let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4); + (d, m) + } +} +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735 +fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}") + } + let mut ys_index = 0; + for x in xs.iter() { + let d = x.d.to_f32(); + let min = x.dmin.to_f32(); + let q = &x.qs; + let mut is = 0; + for j in (0..QK_K).step_by(64) { + let q = &q[j / 2..j / 2 + 32]; + let (sc, m) = get_scale_min_k4(is, &x.scales); + let d1 = d * sc as f32; + let m1 = min * m as f32; + let (sc, m) = get_scale_min_k4(is + 1, &x.scales); + let d2 = d * sc as f32; + let m2 = min * m as f32; + for q in q { + let y = d1 * (q & 0xF) as f32 - m1; + ys[ys_index] = y; + ys_index += 1; + } + for q in q { + let y = d2 * (q >> 4) as f32 - m2; + ys[ys_index] = y; + ys_index += 1; + } + is += 2; + } + } + Ok(()) +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533 +fn dequantize_row_q3k(_xs: &[BlockQ3K], _ys: &mut [f32]) -> Result<()> { + todo!() +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 +fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}") + } + let mut ys_index = 0; + for x in xs.iter() { + let d = x.d.to_f32(); + let min = x.dmin.to_f32(); + let ql = &x.qs; + let qh = &x.qh; + let mut is = 0; + let mut u1 = 1; + let mut u2 = 2; + for j in (0..QK_K).step_by(64) { + let ql = &ql[j / 2..j / 2 + 32]; + let (sc, m) = get_scale_min_k4(is, &x.scales); + let d1 = d * sc as f32; + let m1 = min * m as f32; + let (sc, m) = get_scale_min_k4(is + 1, &x.scales); + let d2 = d * sc as f32; + let m2 = min * m as f32; + for (ql, qh) in ql.iter().zip(qh) { + let to_add = if qh & u1 != 0 { 16 } else { 1 }; + let y = d1 * ((ql & 0xF) + to_add) as f32 - m1; + ys[ys_index] = y; + ys_index += 1; + } + for (ql, qh) in ql.iter().zip(qh) { + let to_add = if qh & u2 != 0 { 16 } else { 1 }; + let y = d2 * ((ql >> 4) + to_add) as f32 - m2; + ys[ys_index] = y; + ys_index += 1; + } + is += 2; + u1 <<= 2; + u2 <<= 2; + } + } + Ok(()) +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067 +fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}") + } + for x in xs.iter() { + let d = x.d.to_f32(); + let ql = &x.ql; + let qh = &x.qh; + let sc = &x.scales; + for n in (0..QK_K).step_by(128) { + let idx = n / 128; + let ys = &mut ys[n..]; + let sc = &sc[8 * idx..]; + let ql = &ql[64 * idx..]; + let qh = &qh[32 * idx..]; + for l in 0..32 { + let is = l / 16; + let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32; + let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32; + let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32; + let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32; + ys[l] = d * sc[is] as f32 * q1 as f32; + ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32; + ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32; + ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32; + } + } + } + Ok(()) +} + +// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Magic { + Ggjt, + Ggla, + Ggmf, + Ggml, + Ggsn, +} + +impl TryFrom<u32> for Magic { + type Error = crate::Error; + fn try_from(value: u32) -> Result<Self> { + let magic = match value { + 0x67676a74 => Self::Ggjt, + 0x67676c61 => Self::Ggla, + 0x67676d66 => Self::Ggmf, + 0x67676d6c => Self::Ggml, + 0x6767736e => Self::Ggsn, + _ => crate::bail!("unknown magic {value:08x}"), + }; + Ok(magic) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VersionedMagic { + GgmlUnversioned, + GgmfV1, + GgjtV1, + GgjtV2, + GgjtV3, +} + +impl VersionedMagic { + fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> { + let magic = reader.read_u32::<LittleEndian>()?; + let magic = Magic::try_from(magic)?; + if magic == Magic::Ggml { + return Ok(Self::GgmlUnversioned); + } + let version = reader.read_u32::<LittleEndian>()?; + let versioned_magic = match (magic, version) { + (Magic::Ggmf, 1) => Self::GgmfV1, + (Magic::Ggjt, 1) => Self::GgjtV1, + (Magic::Ggjt, 2) => Self::GgjtV2, + (Magic::Ggjt, 3) => Self::GgjtV3, + _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"), + }; + Ok(versioned_magic) + } + + fn align32(&self) -> bool { + match self { + Self::GgmlUnversioned | Self::GgmfV1 => false, + Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HParams { + pub n_vocab: u32, + pub n_embd: u32, + pub n_mult: u32, + pub n_head: u32, + pub n_layer: u32, + pub n_rot: u32, + pub ftype: u32, +} + +impl HParams { + fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> { + let n_vocab = reader.read_u32::<LittleEndian>()?; + let n_embd = reader.read_u32::<LittleEndian>()?; + let n_mult = reader.read_u32::<LittleEndian>()?; + let n_head = reader.read_u32::<LittleEndian>()?; + let n_layer = reader.read_u32::<LittleEndian>()?; + let n_rot = reader.read_u32::<LittleEndian>()?; + let ftype = reader.read_u32::<LittleEndian>()?; + Ok(Self { + n_vocab, + n_embd, + n_mult, + n_head, + n_layer, + n_rot, + ftype, + }) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Vocab { + pub token_score_pairs: Vec<(Vec<u8>, f32)>, +} + +impl Vocab { + fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> { + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556 + let mut token_score_pairs = Vec::with_capacity(n_vocab); + for _index in 0..n_vocab { + let len = reader.read_u32::<LittleEndian>()? as usize; + let mut word = vec![0u8; len]; + reader.read_exact(&mut word)?; + let score = reader.read_f32::<LittleEndian>()?; + token_score_pairs.push((word, score)) + } + Ok(Self { token_score_pairs }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GgmlDType { + F32, + F16, + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2K, + Q3K, + Q4K, + Q5K, + Q6K, +} + +impl GgmlDType { + fn from_u32(u: u32) -> Result<Self> { + let dtype = match u { + 0 => Self::F32, + 1 => Self::F16, + 2 => Self::Q4_0, + 3 => Self::Q4_1, + 6 => Self::Q5_0, + 7 => Self::Q5_1, + 8 => Self::Q8_0, + 9 => Self::Q8_1, + 10 => Self::Q2K, + 11 => Self::Q3K, + 12 => Self::Q4K, + 13 => Self::Q5K, + 14 => Self::Q6K, + _ => crate::bail!("unknown dtype for tensor {u}"), + }; + Ok(dtype) + } + + fn type_size(&self) -> usize { + match self { + Self::F32 => 4, + Self::F16 => 2, + Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(), + Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(), + Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(), + Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(), + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932 + Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(), + Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(), + Self::Q2K => std::mem::size_of::<BlockQ2K>(), + Self::Q3K => std::mem::size_of::<BlockQ3K>(), + Self::Q4K => std::mem::size_of::<BlockQ4K>(), + Self::Q5K => std::mem::size_of::<BlockQ5K>(), + Self::Q6K => std::mem::size_of::<BlockQ6K>(), + } + } + + fn blck_size(&self) -> usize { + match self { + Self::F32 => 1, + Self::F16 => 1, + Self::Q4_0 => QK4_0, + Self::Q4_1 => QK4_1, + Self::Q5_0 => QK5_0, + Self::Q5_1 => QK5_1, + Self::Q8_0 => QK8_0, + Self::Q8_1 => QK8_1, + Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K, + } + } +} + +#[derive(Debug)] +pub struct Content { + pub magic: VersionedMagic, + pub hparams: HParams, + pub vocab: Vocab, + pub tensors: Vec<(String, Tensor)>, +} + +fn read_one_tensor<R: std::io::Seek + std::io::Read>( + reader: &mut R, + magic: VersionedMagic, + device: &Device, +) -> Result<(String, Tensor)> { + let n_dims = reader.read_u32::<LittleEndian>()?; + let name_len = reader.read_u32::<LittleEndian>()?; + let dtype = reader.read_u32::<LittleEndian>()?; + let dtype = GgmlDType::from_u32(dtype)?; + let mut dims = vec![0u32; n_dims as usize]; + reader.read_u32_into::<LittleEndian>(&mut dims)?; + let mut name = vec![0u8; name_len as usize]; + reader.read_exact(&mut name)?; + let name = String::from_utf8_lossy(&name).into_owned(); + + if magic.align32() { + let pos = reader.stream_position()?; + reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?; + } + let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>(); + let tensor_elems = dims.iter().product::<usize>(); + let size_in_bytes = tensor_elems * dtype.type_size() / dtype.blck_size(); + println!("{name} {dtype:?} {dims:?}"); + // TODO: Mmap version to avoid copying the data around? + let mut raw_data = vec![0u8; size_in_bytes]; + reader.read_exact(&mut raw_data)?; + let tensor = match dtype { + GgmlDType::F32 => Tensor::from_raw_buffer(&raw_data, DType::F32, &dims, device)?, + GgmlDType::F16 => Tensor::from_raw_buffer(&raw_data, DType::F16, &dims, device)?, + GgmlDType::Q2K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ2K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ2K, n_blocks) }; + dequantize_row_q2k(raw_data, &mut f32_data)?; + // Maybe we should use bf16 instead? + Tensor::from_vec(f32_data, dims, device)? + } + GgmlDType::Q3K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ3K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ3K, n_blocks) }; + dequantize_row_q3k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } + GgmlDType::Q4K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ4K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ4K, n_blocks) }; + dequantize_row_q4k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } + GgmlDType::Q5K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ5K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ5K, n_blocks) }; + dequantize_row_q5k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } + GgmlDType::Q6K => { + let mut f32_data = vec![0f32; tensor_elems]; + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>(); + let raw_data = + unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) }; + dequantize_row_q6k(raw_data, &mut f32_data)?; + Tensor::from_vec(f32_data, dims, device)? + } + _ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"), + }; + Ok((name, tensor)) +} + +impl Content { + pub fn read<R: std::io::Seek + std::io::Read>( + reader: &mut R, + device: &Device, + ) -> Result<Content> { + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 + let last_position = reader.seek(std::io::SeekFrom::End(0))?; + reader.seek(std::io::SeekFrom::Start(0))?; + let magic = VersionedMagic::read(reader)?; + let hparams = HParams::read(reader)?; + let vocab = Vocab::read(reader, hparams.n_vocab as usize)?; + let mut tensors = vec![]; + + while reader.stream_position()? != last_position { + let (name, tensor) = read_one_tensor(reader, magic, device)?; + tensors.push((name, tensor)) + } + Ok(Self { + magic, + hparams, + vocab, + tensors, + }) + } +} |