diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-04 10:14:57 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-04 10:14:57 +0200 |
commit | b13a82a4387a55df07bec4e2eb6f7a8ebd0b98a2 (patch) | |
tree | aed5a019e7e053900ffa5be57ddfd20bdfad8582 | |
parent | 59b18d974ec3cad6963b774aa245e23f8c80414f (diff) | |
download | candle-b13a82a4387a55df07bec4e2eb6f7a8ebd0b98a2.tar.gz candle-b13a82a4387a55df07bec4e2eb6f7a8ebd0b98a2.tar.bz2 candle-b13a82a4387a55df07bec4e2eb6f7a8ebd0b98a2.zip |
Separate quantized phi-3 implementation. (#2157)
* Separate quantized phi-3 implementation.
* Integrate the quantized phi3 model.=
* Small fixes, get the generation to work properly.
* Keep the old llama implementation around.
* Change the default.
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 3 | ||||
-rw-r--r-- | candle-core/src/sort.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/quantized-phi/main.rs | 18 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/phi3.rs | 8 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_phi3.rs | 301 |
7 files changed, 323 insertions, 12 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index c0f6a844..e00566ca 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -676,9 +676,6 @@ impl BackendStorage for MetalStorage { } } - if layout.is_contiguous() { - } else { - } Ok(Self::new(buffer, device.clone(), el_count, dtype)) } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 6bfa3ca7..614a37fe 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -178,7 +178,7 @@ impl crate::CustomOp1 for ArgSort { device.metal_device(), &command_buffer, kernels, - &name, + name, nrows, ncols, ncols_pad, diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs index 7d255f58..e2211844 100644 --- a/candle-examples/examples/quantized-phi/main.rs +++ b/candle-examples/examples/quantized-phi/main.rs @@ -13,8 +13,9 @@ use candle::Tensor; use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_examples::token_output_stream::TokenOutputStream; -use candle_transformers::models::quantized_llama::ModelWeights as Phi3; +use candle_transformers::models::quantized_llama::ModelWeights as Phi3b; use candle_transformers::models::quantized_phi::ModelWeights as Phi2; +use candle_transformers::models::quantized_phi3::ModelWeights as Phi3; const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. "; @@ -24,6 +25,9 @@ enum Which { Phi2, #[value(name = "phi-3")] Phi3, + /// Alternative implementation of phi-3, based on llama. + #[value(name = "phi-3b")] + Phi3b, } #[derive(Parser, Debug)] @@ -84,7 +88,7 @@ struct Args { repeat_last_n: usize, /// The model size to use. - #[arg(long, default_value = "phi-2")] + #[arg(long, default_value = "phi-3b")] which: Which, } @@ -96,7 +100,7 @@ impl Args { let api = hf_hub::api::sync::Api::new()?; let repo = match self.which { Which::Phi2 => "microsoft/phi-2", - Which::Phi3 => "microsoft/Phi-3-mini-4k-instruct", + Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct", }; let api = api.model(repo.to_string()); api.get("tokenizer.json")? @@ -114,6 +118,11 @@ impl Args { Which::Phi3 => ( "microsoft/Phi-3-mini-4k-instruct-gguf", "Phi-3-mini-4k-instruct-q4.gguf", + "main", + ), + Which::Phi3b => ( + "microsoft/Phi-3-mini-4k-instruct-gguf", + "Phi-3-mini-4k-instruct-q4.gguf", "5eef2ce24766d31909c0b269fe90c817a8f263fb", ), }; @@ -145,6 +154,7 @@ fn format_size(size_in_bytes: usize) -> String { enum Model { Phi2(Phi2), Phi3(Phi3), + Phi3b(Phi3b), } impl Model { @@ -152,6 +162,7 @@ impl Model { match self { Self::Phi2(m) => m.forward(xs, pos), Self::Phi3(m) => m.forward(xs, pos), + Self::Phi3b(m) => m.forward(xs, pos), } } } @@ -203,6 +214,7 @@ fn main() -> anyhow::Result<()> { match args.which { Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?), Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?), + Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?), } }; println!("model built"); diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c08e44fe..814ca0b9 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -350,7 +350,7 @@ pub fn call_unary_contiguous_tiled( let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); let tile_size = 2; - let tiles = length.div_ceil(tile_size); + let tiles = (length + tile_size - 1) / tile_size; encoder.set_compute_pipeline_state(&pipeline); diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 02f84158..de2430a2 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -40,6 +40,7 @@ pub mod quantized_mixformer; pub mod quantized_moondream; pub mod quantized_mpt; pub mod quantized_phi; +pub mod quantized_phi3; pub mod quantized_recurrent_gemma; pub mod quantized_rwkv_v5; pub mod quantized_rwkv_v6; diff --git a/candle-transformers/src/models/phi3.rs b/candle-transformers/src/models/phi3.rs index d305e175..a5e3e9a9 100644 --- a/candle-transformers/src/models/phi3.rs +++ b/candle-transformers/src/models/phi3.rs @@ -24,19 +24,19 @@ pub struct Config { } impl Config { - fn head_dim(&self) -> usize { + pub fn head_dim(&self) -> usize { self.hidden_size / self.num_attention_heads } } #[derive(Debug, Clone)] -struct RotaryEmbedding { +pub struct RotaryEmbedding { sin: Tensor, cos: Tensor, } impl RotaryEmbedding { - fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { + pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { let dim = cfg.head_dim(); let max_seq_len = cfg.max_position_embeddings; let inv_freq: Vec<_> = (0..dim) @@ -55,7 +55,7 @@ impl RotaryEmbedding { }) } - fn apply_rotary_emb_qkv( + pub fn apply_rotary_emb_qkv( &self, q: &Tensor, k: &Tensor, diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs new file mode 100644 index 00000000..ef404ca0 --- /dev/null +++ b/candle-transformers/src/models/quantized_phi3.rs @@ -0,0 +1,301 @@ +use std::collections::HashMap; + +use candle::quantized::gguf_file; +use candle::quantized::QTensor; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{Embedding, RmsNorm}; + +pub const MAX_SEQ_LEN: usize = 4096; + +#[derive(Debug, Clone)] +struct QLinear { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QLinear { + fn new<R: std::io::Read + std::io::Seek>( + ct: &gguf_file::Content, + r: &mut R, + name: &str, + device: &Device, + ) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + let w = ct.tensor(r, &format!("{name}.weight"), device)?; + let inner = candle::quantized::QMatMul::from_qtensor(w)?; + Ok(Self { inner, span }) + } +} + +impl Module for QLinear { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + ffn_up: QLinear, + ffn_down: QLinear, + i_size: usize, +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let up_states = xs.apply(&self.ffn_up)?; + let gate = up_states.narrow(D::Minus1, 0, self.i_size)?; + let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?; + let up_states = (up_states * gate.silu()?)?; + up_states.apply(&self.ffn_down) + } +} + +fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> { + let w = w.dequantize(&w.device())?; + let rms = RmsNorm::new(w, eps); + Ok(rms) +} + +#[derive(Debug, Clone)] +struct LayerWeights { + attn_qkv: QLinear, + attn_output: QLinear, + attn_norm: RmsNorm, + ffn_norm: RmsNorm, + mlp: Mlp, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + neg_inf: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> { + let shape = mask.shape(); + let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); + let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?; + let cos = self.cos.narrow(0, index_pos, seq_len)?; + let sin = self.sin.narrow(0, index_pos, seq_len)?; + candle_nn::rotary_emb::rope(&xs.contiguous()?, &cos, &sin) + } + + fn forward_attn( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + index_pos: usize, + ) -> Result<Tensor> { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + let qkv = self.attn_qkv.forward(x)?; + + let query_pos = self.n_head * self.head_dim; + let q = qkv.narrow(D::Minus1, 0, query_pos)?; + let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?; + let v = qkv.narrow( + D::Minus1, + query_pos + self.n_kv_head * self.head_dim, + self.n_kv_head * self.head_dim, + )?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + + let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k.contiguous()?, v.contiguous()?), + Some((k_cache, v_cache)) => { + if index_pos == 0 { + (k.contiguous()?, v.contiguous()?) + } else { + let k = Tensor::cat(&[k_cache, &k], 2)?; + let v = Tensor::cat(&[v_cache, &v], 2)?; + (k.contiguous()?, v.contiguous()?) + } + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; + let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let att = match mask { + None => att, + Some(mask) => { + let mask = mask.broadcast_as(att.shape())?; + masked_fill(&att, &mask, &self.neg_inf)? + } + }; + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attn_output.forward(&y)?; + Ok(y) + } +} + +#[derive(Debug, Clone)] +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec<LayerWeights>, + output_norm: RmsNorm, + output: QLinear, + masks: HashMap<usize, Tensor>, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis( + head_dim: usize, + freq_base: f32, + device: &Device, +) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_gguf<R: std::io::Seek + std::io::Read>( + ct: gguf_file::Content, + reader: &mut R, + device: &Device, + ) -> Result<Self> { + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("phi3.block_count")?.to_u32()? as usize; + let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize; + let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize; + let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize; + let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; + let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; + + let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = tok_embeddings.dequantize(device)?; + let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?; + let output = QLinear::new(&ct, reader, "output", device)?; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?; + let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?; + let mlp = Mlp { + ffn_up, + ffn_down, + i_size, + }; + let attn_norm = rms_norm( + ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, + rms_eps, + )?; + let ffn_norm = rms_norm( + ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?, + rms_eps, + )?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + layers.push(LayerWeights { + attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?, + attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?, + attn_norm, + ffn_norm, + mlp, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: embedding_length / head_count, + cos: cos.clone(), + sin: sin.clone(), + neg_inf: neg_inf.clone(), + kv_cache: None, + span_attn, + span_rot, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + output_norm, + output, + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), device)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = xs.dims2()?; + let mask = if seq_len == 1 { + None + } else { + Some(self.mask(seq_len, xs.device())?) + }; + let _enter = self.span.enter(); + let mut xs = self.tok_embeddings.forward(xs)?; + for layer in self.layers.iter_mut() { + let residual = &xs; + let ys = xs.apply(&layer.attn_norm)?; + let ys = layer.forward_attn(&ys, mask.as_ref(), index_pos)?; + let ys = (ys + residual)?; + let residual = &ys; + let ys = ys.apply(&layer.ffn_norm)?; + let ys = layer.mlp.forward(&ys)?; + xs = (ys + residual)? + } + let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); + self.output.forward(&xs) + } +} |