-rw-r--r--  candle-core/src/quantized/k_quants.rs        |   9
-rw-r--r--  candle-core/src/quantized/neon.rs            | 368
-rw-r--r--  candle-examples/examples/quantized/main.rs   | 368
-rw-r--r--  candle-examples/examples/quantized/model.rs  | 367
4 files changed, 740 insertions(+), 372 deletions(-)
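This change adds NEON (aarch64) kernels for the Q2_K, Q3_K and Q5_K super-block formats to candle-core/src/quantized/neon.rs, wires them into the compile-time dispatch in k_quants.rs, and splits the quantized example by moving the llama model code out of main.rs into a new model.rs module. The #[cfg(target_feature = "neon")] guards resolve at compile time; since neon is part of the default target features on aarch64, the new paths are taken automatically there and compiled out everywhere else.

As an aid for reading the first new kernel, the following scalar sketch (illustrative only, not code from this diff) shows how one Q5_K weight is reconstructed from the packed arrays: the low four bits live in qs and the fifth bit in qh. The NEON code below performs the same extraction 32 lanes at a time with the m4b, mone and mtwo masks, shifting qhbits right by two after each 64-value chunk.

// Hypothetical helper, named here for illustration only: maps a flat
// weight index i (0..256) within one BlockQ5K to its unsigned 5-bit quant.
fn q5k_unpack(qs: &[u8; 128], qh: &[u8; 32], i: usize) -> u8 {
    let chunk = i / 64; // four chunks of 64 values per super-block
    let l = i % 64;
    // low nibble for the first 32 values of a chunk, high nibble for the rest
    let lo = (qs[32 * chunk + l % 32] >> (4 * (l / 32))) & 0xF;
    // the fifth bit: two bits of each qh byte are consumed per chunk
    let hi = (qh[l % 32] >> (2 * chunk + l / 32)) & 1;
    lo | (hi << 4)
}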
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index e7404529..65fd6a6e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -676,6 +676,9 @@ impl GgmlType for BlockQ2K { #[cfg(target_feature = "avx")] return super::avx::vec_dot_q2k_q8k(n, xs, ys); + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q2k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") } @@ -843,6 +846,9 @@ impl GgmlType for BlockQ3K { #[cfg(target_feature = "avx")] return super::avx::vec_dot_q3k_q8k(n, xs, ys); + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q3k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") } @@ -1301,6 +1307,9 @@ impl GgmlType for BlockQ5K { #[cfg(target_feature = "avx")] return super::avx::vec_dot_q5k_q8k(n, xs, ys); + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q5k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 69d616f4..7f76dadc 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,4 +1,6 @@ -use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; +use super::k_quants::{ + BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; @@ -282,6 +284,104 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res } #[inline(always)] +pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut utmp = [0u32; 4]; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let m4b = vdupq_n_u8(0xF); + let mone = vdupq_n_u8(1); + let mtwo = vdupq_n_u8(2); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = y.d * x.dmin.to_f32(); + + let q8sums = vpaddq_s16( + vld1q_s16(y.bsums.as_ptr()), + vld1q_s16(y.bsums.as_ptr().add(8)), + ); + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8)); + let mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + let prod = vaddq_s32( + vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)), + ); + let sumi_mins = vaddvq_s32(prod); + + let mut scales = utmp.as_ptr() as *const u8; + + let mut q5 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mut qhbits = vld1q_u8_x2(x.qh.as_ptr()); + + let mut sumi = 0i32; + + for _j in 0..QK_K / 64 { + let q5bits = vld1q_u8_x2(q5); + q5 = q5.add(32); + let q8bytes = vld1q_s8_x4(q8); + q8 = q8.add(64); + + let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4); + let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4); + let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3); + let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3); + qhbits.0 = vshrq_n_u8(qhbits.0, 2); + qhbits.1 = vshrq_n_u8(qhbits.1, 2); + + let 
q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0)); + let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1)); + let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); + let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); + + // TODO: dotprod + + let p0 = vaddq_s16( + vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)), + vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)), + ); + let p1 = vaddq_s16( + vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)), + vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)), + ); + sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32; + scales = scales.add(1); + + let p2 = vaddq_s16( + vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)), + vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)), + ); + let p3 = vaddq_s16( + vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)), + vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)), + ); + sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32; + scales = scales.add(1); + } + sumf += d * sumi as f32 - dmin * sumi_mins as f32; + } + } + Ok(sumf) +} + +#[inline(always)] pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> { if n % QK_K != 0 { crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") @@ -289,9 +389,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res let mut sumf = 0f32; let mut utmp = [0u32; 4]; let mut scales = [0u8; 16]; - let kmask1: u32 = 0x3f3f3f3f; - let kmask2: u32 = 0x0f0f0f0f; - let kmask3: u32 = 0x03030303; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; unsafe { let m4b = vdupq_n_u8(0xF); @@ -309,13 +409,13 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res let mins8 = vld1_u32( [ - utmp[1] & kmask1, - ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), + utmp[1] & KMASK1, + ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4), ] .as_ptr(), ); - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[0] &= KMASK1; let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); let prod = vaddq_s32( @@ -373,3 +473,255 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res } Ok(sumf) } + +#[inline(always)] +pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut utmp = [0u32; 4]; + let mut aux = [0u32; 3]; + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + unsafe { + let m3b = vdupq_n_u8(0x3); + let m0 = vdupq_n_u8(1); + let m1 = vshlq_n_u8(m0, 1); + let m2 = vshlq_n_u8(m0, 2); + let m3 = vshlq_n_u8(m0, 3); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let mut q3 = x.qs.as_ptr(); + let qh = x.hmask.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mut qhbits = vld1q_u8_x2(qh); + + let mut isum = 0i32; + + // Set up scales + LittleEndian::read_u32_into(&x.scales, &mut aux); + + utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4); + utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4); + utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4); + utmp[0] = (aux[0] & KMASK2) | 
((aux[2] & KMASK1) << 4); + + let mut scale = utmp.as_mut_ptr() as *mut i8; + for j in 0..16 { + *scale.add(j) -= 32i8 + } + + for j in 0..QK_K / 128 { + let q3bits = vld1q_u8_x2(q3); + q3 = q3.add(32); + let q8bytes_1 = vld1q_s8_x4(q8); + q8 = q8.add(64); + let q8bytes_2 = vld1q_s8_x4(q8); + q8 = q8.add(64); + + let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2); + let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2); + let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1); + let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1); + + let q3bytes_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)), + vreinterpretq_s8_u8(q3h_0), + ); + let q3bytes_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)), + vreinterpretq_s8_u8(q3h_1), + ); + let q3bytes_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)), + vreinterpretq_s8_u8(q3h_2), + ); + let q3bytes_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)), + vreinterpretq_s8_u8(q3h_3), + ); + + // TODO: dotprod + let p0 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)), + vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)), + ); + let p1 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)), + vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)), + ); + let p2 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)), + vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)), + ); + let p3 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)), + vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)), + ); + isum += vaddvq_s16(p0) as i32 * *scale as i32 + + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 + + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 + + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + scale = scale.add(4); + + let q3h_0 = vbicq_u8(m2, qhbits.0); + let q3h_1 = vbicq_u8(m2, qhbits.1); + let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1); + let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1); + + let q3bytes_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)), + vreinterpretq_s8_u8(q3h_0), + ); + let q3bytes_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)), + vreinterpretq_s8_u8(q3h_1), + ); + let q3bytes_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)), + vreinterpretq_s8_u8(q3h_2), + ); + let q3bytes_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)), + vreinterpretq_s8_u8(q3h_3), + ); + + // TODO: dotprod + let p0 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)), + vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)), + ); + let p1 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)), + vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)), + ); + let p2 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)), + vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)), + ); + let p3 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)), + vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)), + ); + isum += vaddvq_s16(p0) as i32 * *scale as i32 + + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 + + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 + + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + scale = scale.add(4); + + if j == 0 { + qhbits.0 = vshrq_n_u8(qhbits.0, 4); + qhbits.1 = vshrq_n_u8(qhbits.1, 4); + } + } + sumf += d * isum as f32; + } + } + Ok(sumf) +} + 
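The last kernel in this file, vec_dot_q2k_q8k below, follows the same pattern. Here is a scalar sketch of the quantity it computes per super-block (a reading aid, not the crate's actual fallback; it assumes the pub(crate) block fields visible inside the quantized module): the low nibble of each scales byte is a sub-block scale, the high nibble a sub-block min, and the min contribution is folded in through the precomputed 16-wide sums in bsums, which is exactly what the mins_and_scales / q8sums section of the NEON code vectorizes.

fn q2k_q8k_dot_ref(x: &BlockQ2K, y: &BlockQ8K) -> f32 {
    let d = y.d * x.d.to_f32();
    let dmin = y.d * x.dmin.to_f32();
    // Min part: high nibbles of scales against the 16-wide q8 sums.
    let mut summs = 0i32;
    for j in 0..16 {
        summs += (x.scales[j] >> 4) as i32 * y.bsums[j] as i32;
    }
    // Quant part: 2-bit values, unpacked 128 at a time as in the kernel.
    let mut isum = 0i32;
    let mut q8 = 0usize; // running index into y.qs
    for half in 0..2 {
        let qs = &x.qs[32 * half..32 * half + 32];
        for shift in 0..4 {
            for group in 0..2 {
                let sc = (x.scales[8 * half + 2 * shift + group] & 0xF) as i32;
                let mut s = 0i32;
                for l in 0..16 {
                    let q = (qs[16 * group + l] >> (2 * shift)) & 3;
                    s += q as i32 * y.qs[q8] as i32;
                    q8 += 1;
                }
                isum += sc * s;
            }
        }
    }
    d * isum as f32 - dmin * summs as f32
}

The multiply_accum_with_scale helper in the diff is the vector form of the two inner loops: one vmull_s8 / vaddq_s16 / vaddvq_s16 reduction per 16-value sub-block, weighted by the unpacked scale.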
+#[inline(always)] +pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut aux = [0u8; 16]; + + unsafe { + let m3 = vdupq_n_u8(0x3); + let m4 = vdupq_n_u8(0xF); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = -y.d * x.dmin.to_f32(); + + let mut q2 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + let sc = x.scales.as_ptr(); + + let mins_and_scales = vld1q_u8(sc); + let scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux.as_mut_ptr(), scales); + + let mins = vshrq_n_u8(mins_and_scales, 4); + let q8sums = vld1q_s16_x2(y.bsums.as_ptr()); + let mins16 = int16x8x2_t( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))), + ); + let s0 = vaddq_s32( + vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)), + vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)), + ); + let s1 = vaddq_s32( + vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)), + vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)), + ); + sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32; + + let mut isum = 0i32; + let mut is = 0usize; + + // TODO: dotprod + + for _j in 0..QK_K / 128 { + let q2bits = vld1q_u8_x2(q2); + q2 = q2.add(32); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + let mut q2bytes = int8x16x2_t( + vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)), + vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)), + ); + isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3)); + q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3)); + isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3)); + q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3)); + isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3)); + q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3)); + isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes); + + is += 8; + } + sumf += d * isum as f32; + } + } + Ok(sumf) +} + +#[inline(always)] +unsafe fn multiply_accum_with_scale( + aux: &[u8; 16], + is: usize, + index: usize, + q2bytes: int8x16x2_t, + q8bytes: int8x16x2_t, +) -> i32 { + let p1 = vaddq_s16( + vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)), + vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)), + ); + let p2 = vaddq_s16( + vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)), + vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)), + ); + vaddvq_s16(p1) as i32 * aux[is + index] as i32 + + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32 +} diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index a1e3eabd..53be19b9 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -5,377 +5,17 @@ extern crate intel_mkl_src; extern crate accelerate_src; use clap::{Parser, ValueEnum}; -use std::collections::HashMap; use std::io::Write; use tokenizers::Tokenizer; -use candle::quantized::QTensor; use 
candle::quantized::{ggml_file, gguf_file}; -use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, Module}; +use candle::{Device, Tensor}; use candle_transformers::generation::LogitsProcessor; -const MAX_SEQ_LEN: usize = 4096; -const DEFAULT_PROMPT: &str = "My favorite theorem is "; - -struct RmsNorm { - inner: candle_nn::LayerNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(scale: QTensor, eps: f32) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let scale = scale.dequantize(&Device::Cpu)?; - let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); - Ok(Self { inner, span }) - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -// QMatMul wrapper adding some tracing. -struct QMatMul { - inner: candle::quantized::QMatMul, - span: tracing::Span, -} - -impl QMatMul { - fn from_qtensor(qtensor: QTensor) -> Self { - let inner = candle::quantized::QMatMul::from_qtensor(qtensor); - let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); - Self { inner, span } - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(xs) - } -} - -struct LayerWeights { - attention_wq: QMatMul, - attention_wk: QMatMul, - attention_wv: QMatMul, - attention_wo: QMatMul, - attention_norm: RmsNorm, - feed_forward_w1: QMatMul, - feed_forward_w2: QMatMul, - feed_forward_w3: QMatMul, - ffn_norm: RmsNorm, - n_head: usize, - n_kv_head: usize, - head_dim: usize, - cos: Tensor, - sin: Tensor, - kv_cache: Option<(Tensor, Tensor)>, - span_attn: tracing::Span, - span_rot: tracing::Span, - span_mlp: tracing::Span, -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { - let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; - Ok(m) -} - -impl LayerWeights { - fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let _enter = self.span_rot.enter(); - let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; - let cos = self - .cos - .narrow(0, index_pos, seq_len)? - .reshape((seq_len, n_embd / 2, 1))?; - let sin = self - .sin - .narrow(0, index_pos, seq_len)? - .reshape((seq_len, n_embd / 2, 1))?; - let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; - let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; - // This mimics the llama.cpp behavior. - // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 - // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. - // The resulting y0 and y1 are also interleaved with: - // y0 = x0*cos - x1*sin - // y1 = x0*sin + x1*cos - let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; - let x0 = x.narrow(D::Minus1, 0, 1)?; - let x1 = x.narrow(D::Minus1, 1, 1)?; - let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; - let y1 = (x0.broadcast_mul(&sin)? 
+ x1.broadcast_mul(&cos)?)?; - let rope = Tensor::cat(&[y0, y1], D::Minus1)?; - let rope = rope.flatten_from(D::Minus2)?; - Ok(rope) - } - - fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> { - let _enter = self.span_attn.enter(); - let (b_sz, seq_len, n_embd) = x.dims3()?; - let q = self.attention_wq.forward(x)?; - let k = self.attention_wk.forward(x)?; - let v = self.attention_wv.forward(x)?; - - let q = q - .reshape((b_sz, seq_len, self.n_head, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? - .transpose(1, 2)?; - - let q = self.apply_rotary_emb(&q, index_pos)?; - let k = self.apply_rotary_emb(&k, index_pos)?; - - let (k, v) = match &self.kv_cache { - None => (k, v), - Some((k_cache, v_cache)) => { - let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; - let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?; - (k, v) - } - }; - self.kv_cache = Some((k.clone(), v.clone())); - - // Support for MQA, useful for 70B models. - let k = self.repeat_kv(k)?; - let v = self.repeat_kv(v)?; - - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = mask.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = candle_nn::ops::softmax(&att, D::Minus1)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; - let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; - let y = self.attention_wo.forward(&y)?; - Ok(y) - } +mod model; +use model::ModelWeights; - fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { - let n_rep = self.n_head / self.n_kv_head; - if n_rep == 1 { - Ok(x) - } else { - let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; - let x = x - .unsqueeze(2)? - .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? - .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; - Ok(x) - } - } -} - -struct ModelWeights { - tok_embeddings: Embedding, - layers: Vec<LayerWeights>, - norm: RmsNorm, - output: QMatMul, - masks: HashMap<usize, Tensor>, - span: tracing::Span, - span_output: tracing::Span, -} - -fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { - let theta: Vec<_> = (0..head_dim) - .step_by(2) - .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? - .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? 
- .matmul(&theta.reshape((1, theta.elem_count()))?)?; - let cos = idx_theta.cos()?; - let sin = idx_theta.sin()?; - Ok((cos, sin)) -} - -impl ModelWeights { - fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { - let cpu = &Device::Cpu; - let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; - let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; - let tok_embeddings = ct.remove("tok_embeddings.weight")?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; - let output = ct.remove("output.weight")?; - let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); - for layer_idx in 0..ct.hparams.n_layer { - let prefix = format!("layers.{layer_idx}"); - let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; - let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; - let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; - let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; - let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; - let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; - let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; - let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; - let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; - let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); - let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); - layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq), - attention_wk: QMatMul::from_qtensor(attention_wk), - attention_wv: QMatMul::from_qtensor(attention_wv), - attention_wo: QMatMul::from_qtensor(attention_wo), - attention_norm: RmsNorm::new(attention_norm, 1e-5)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), - ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, - n_head: ct.hparams.n_head as usize, - n_kv_head: ct.hparams.n_head as usize / gqa, - head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, - cos: cos.clone(), - sin: sin.clone(), - kv_cache: None, - span_attn, - span_rot, - span_mlp, - }) - } - let span = tracing::span!(tracing::Level::TRACE, "model"); - let span_output = tracing::span!(tracing::Level::TRACE, "output"); - Ok(Self { - tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), - layers, - norm, - output: QMatMul::from_qtensor(output), - masks: HashMap::new(), - span, - span_output, - }) - } - - fn from_gguf<R: std::io::Seek + std::io::Read>( - ct: gguf_file::Content, - reader: &mut R, - ) -> Result<Self> { - let cpu = &Device::Cpu; - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - - // Parameter extraction from metadata. - let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("llama.block_count")?.to_u32()? as usize; - let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; - let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; - // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. 
- let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; - - let rope_freq_base = md_get("llama.rope.freq_base") - .and_then(|m| m.to_f32()) - .unwrap_or(10000f32); - let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; - - let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; - let output = ct.tensor(reader, "output.weight")?; - let mut layers = Vec::with_capacity(block_count); - for layer_idx in 0..block_count { - let prefix = format!("blk.{layer_idx}"); - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; - let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; - let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; - let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; - let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; - let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; - let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); - let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); - layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq), - attention_wk: QMatMul::from_qtensor(attention_wk), - attention_wv: QMatMul::from_qtensor(attention_wv), - attention_wo: QMatMul::from_qtensor(attention_wo), - attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), - ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, - n_head: head_count, - n_kv_head: head_count_kv, - head_dim: embedding_length / head_count, - cos: cos.clone(), - sin: sin.clone(), - kv_cache: None, - span_attn, - span_rot, - span_mlp, - }) - } - let span = tracing::span!(tracing::Level::TRACE, "model"); - let span_output = tracing::span!(tracing::Level::TRACE, "output"); - Ok(Self { - tok_embeddings: Embedding::new(tok_embeddings, embedding_length), - layers, - norm, - output: QMatMul::from_qtensor(output), - masks: HashMap::new(), - span, - span_output, - }) - } - - fn mask(&mut self, t: usize) -> Result<Tensor> { - if let Some(mask) = self.masks.get(&t) { - Ok(mask.clone()) - } else { - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; - self.masks.insert(t, mask.clone()); - Ok(mask) - } - } - - fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let (_b_sz, seq_len) = x.dims2()?; - let mask = self.mask(seq_len)?; - let _enter = self.span.enter(); - let mut layer_in = self.tok_embeddings.forward(x)?; - for layer in self.layers.iter_mut() { - let x = layer_in; - let residual = &x; - let x = layer.attention_norm.forward(&x)?; - let attn = layer.forward_attn(&x, &mask, index_pos)?; - let x = (attn + residual)?; - - // MLP - let _enter = layer.span_mlp.enter(); - let residual = &x; - let x = layer.ffn_norm.forward(&x)?; - let w1 = 
layer.feed_forward_w1.forward(&x)?; - let w3 = layer.feed_forward_w3.forward(&x)?; - let mlp = layer - .feed_forward_w2 - .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?; - layer_in = (mlp + residual)?; - } - let x = self.norm.forward(&layer_in)?; - let x = x.i((.., seq_len - 1, ..))?; - let _enter = self.span_output.enter(); - self.output.forward(&x) - } -} +const DEFAULT_PROMPT: &str = "My favorite theorem is "; #[derive(Clone, Debug, Copy, ValueEnum)] enum Which { diff --git a/candle-examples/examples/quantized/model.rs b/candle-examples/examples/quantized/model.rs new file mode 100644 index 00000000..27ac18a9 --- /dev/null +++ b/candle-examples/examples/quantized/model.rs @@ -0,0 +1,367 @@ +use std::collections::HashMap; + +use candle::quantized::QTensor; +use candle::quantized::{ggml_file, gguf_file}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module}; + +const MAX_SEQ_LEN: usize = 4096; + +struct RmsNorm { + inner: candle_nn::LayerNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn new(scale: QTensor, eps: f32) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let scale = scale.dequantize(&Device::Cpu)?; + let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); + Ok(Self { inner, span }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +// QMatMul wrapper adding some tracing. +struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn from_qtensor(qtensor: QTensor) -> Self { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor); + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Self { inner, span } + } + + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +struct LayerWeights { + attention_wq: QMatMul, + attention_wk: QMatMul, + attention_wv: QMatMul, + attention_wo: QMatMul, + attention_norm: RmsNorm, + feed_forward_w1: QMatMul, + feed_forward_w2: QMatMul, + feed_forward_w3: QMatMul, + ffn_norm: RmsNorm, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, + span_mlp: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let cos = self + .cos + .narrow(0, index_pos, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let sin = self + .sin + .narrow(0, index_pos, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + // This mimics the llama.cpp behavior. + // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 + // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. 
+ // The resulting y0 and y1 are also interleaved with: + // y0 = x0*cos - x1*sin + // y1 = x0*sin + x1*cos + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + Ok(rope) + } + + fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((k_cache, v_cache)) => { + let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; + let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?; + (k, v) + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + // Support for MQA, useful for 70B models. + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = mask.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attention_wo.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { + let n_rep = self.n_head / self.n_kv_head; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; + let x = x + .unsqueeze(2)? + .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? + .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; + Ok(x) + } + } +} + +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec<LayerWeights>, + norm: RmsNorm, + output: QMatMul, + masks: HashMap<usize, Tensor>, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? 
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { + let cpu = &Device::Cpu; + let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; + let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; + let tok_embeddings = ct.remove("tok_embeddings.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; + let output = ct.remove("output.weight")?; + let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); + for layer_idx in 0..ct.hparams.n_layer { + let prefix = format!("layers.{layer_idx}"); + let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; + let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; + let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; + let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; + let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; + let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm, 1e-5)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, + n_head: ct.hparams.n_head as usize, + n_kv_head: ct.hparams.n_head as usize / gqa, + head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + + pub fn from_gguf<R: std::io::Seek + std::io::Read>( + ct: gguf_file::Content, + reader: &mut R, + ) -> Result<Self> { + let cpu = &Device::Cpu; + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("llama.block_count")?.to_u32()? as usize; + let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. 
+ let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + + let rope_freq_base = md_get("llama.rope.freq_base") + .and_then(|m| m.to_f32()) + .unwrap_or(10000f32); + let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; + + let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; + let output = ct.tensor(reader, "output.weight")?; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; + let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; + let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; + let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; + let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: embedding_length / head_count, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize) -> Result<Tensor> { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = x.dims2()?; + let mask = self.mask(seq_len)?; + let _enter = self.span.enter(); + let mut layer_in = self.tok_embeddings.forward(x)?; + for layer in self.layers.iter_mut() { + let x = layer_in; + let residual = &x; + let x = layer.attention_norm.forward(&x)?; + let attn = layer.forward_attn(&x, &mask, index_pos)?; + let x = (attn + residual)?; + + // MLP + let _enter = layer.span_mlp.enter(); + let residual = &x; + let x = layer.ffn_norm.forward(&x)?; + let w1 = 
layer.feed_forward_w1.forward(&x)?; + let w3 = layer.feed_forward_w3.forward(&x)?; + let mlp = layer + .feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?; + layer_in = (mlp + residual)?; + } + let x = self.norm.forward(&layer_in)?; + let x = x.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); + self.output.forward(&x) + } +}
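After the split, main.rs keeps only the CLI and the generation loop and pulls the model in via mod model; use model::ModelWeights;. A minimal sketch of how the two constructors are meant to be driven (illustrative only; the real example adds timing, tensor listing and a gqa flag for legacy ggml files, whose headers do not record the grouped-query-attention factor):

mod model;
use model::ModelWeights;
use candle::quantized::{ggml_file, gguf_file};

fn load(path: &std::path::Path, gqa: usize) -> anyhow::Result<ModelWeights> {
    let mut file = std::fs::File::open(path)?;
    let model = match path.extension().and_then(|e| e.to_str()) {
        // GGUF carries its hyper-parameters as metadata, read by from_gguf.
        Some("gguf") => {
            let content = gguf_file::Content::read(&mut file)?;
            ModelWeights::from_gguf(content, &mut file)?
        }
        // Legacy GGML files need the gqa factor passed in (e.g. 8 for 70B).
        _ => {
            let content = ggml_file::Content::read(&mut file)?;
            ModelWeights::from_ggml(content, gqa)?
        }
    };
    Ok(model)
}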