diff options
author | Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> | 2024-07-26 15:32:26 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-26 21:32:26 +0200 |
commit | 0f5cbb08b36a2d962470ec590a2d2bd9770bd12d (patch) | |
tree | a5d6911051646e96fc833664b44c530c76fe4416 /candle-transformers | |
parent | ddafc61055601002622778b7762c15bd60057c1f (diff) | |
download | candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.tar.gz candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.tar.bz2 candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.zip |
Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope
* Clippy
* Format
* Clippy
* Add support for multiple eos tokens:
* Untagged either
* Remove either dep and fix settings.json
* Make the max positional embeddings configurable
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/beit.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/clip/text_model.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/dinov2.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/dinov2reg4.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/encodec.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/eva2.rs | 9 | ||||
-rw-r--r-- | candle-transformers/src/models/llama.rs | 112 | ||||
-rw-r--r-- | candle-transformers/src/models/llava/config.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/attention.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/clip.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/unet_2d.rs | 8 | ||||
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs | 20 | ||||
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/vae.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/t5.rs | 2 |
14 files changed, 125 insertions, 50 deletions
diff --git a/candle-transformers/src/models/beit.rs b/candle-transformers/src/models/beit.rs index 62bdd75a..8f6284a8 100644 --- a/candle-transformers/src/models/beit.rs +++ b/candle-transformers/src/models/beit.rs @@ -288,7 +288,7 @@ impl BeitVisionTransformer { let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?; let vb_b = vb.pp("blocks"); let blocks = (0..depth) - .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads)) .collect::<Result<Vec<_>>>()?; Ok(Self { patch_embed, diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs index 4e4b4c90..51db14ee 100644 --- a/candle-transformers/src/models/clip/text_model.rs +++ b/candle-transformers/src/models/clip/text_model.rs @@ -249,7 +249,7 @@ impl ClipEncoder { let vs = vs.pp("layers"); let mut layers: Vec<ClipEncoderLayer> = Vec::new(); for index in 0..c.num_hidden_layers() { - let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?; + let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?; layers.push(layer) } Ok(ClipEncoder { layers }) diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs index 00e501ce..706dfda0 100644 --- a/candle-transformers/src/models/dinov2.rs +++ b/candle-transformers/src/models/dinov2.rs @@ -214,7 +214,7 @@ impl DinoVisionTransformer { let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?; let vb_b = vb.pp("blocks"); let blocks = (0..depth) - .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads)) .collect::<Result<Vec<_>>>()?; Ok(Self { patch_embed, diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs index 6bbe2e24..1d81703c 100644 --- a/candle-transformers/src/models/dinov2reg4.rs +++ b/candle-transformers/src/models/dinov2reg4.rs @@ -212,7 +212,7 @@ impl DinoVisionTransformer { let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?; let vb_b = vb.pp("blocks"); let blocks = (0..depth) - .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads)) .collect::<Result<Vec<_>>>()?; Ok(Self { patch_embed, diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index 14a85d3e..fb70fb52 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -571,7 +571,7 @@ impl<'a> Layer<'a> { } fn next(&mut self) -> VarBuilder { - let vb = self.vb.pp(&self.cnt.to_string()); + let vb = self.vb.pp(self.cnt.to_string()); self.cnt += 1; vb } diff --git a/candle-transformers/src/models/eva2.rs b/candle-transformers/src/models/eva2.rs index eb2df4cd..013c385d 100644 --- a/candle-transformers/src/models/eva2.rs +++ b/candle-transformers/src/models/eva2.rs @@ -255,14 +255,7 @@ impl EVA2VisionTransformer { let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?; let vb_b = vb.pp("blocks"); let blocks = (0..depth) - .map(|i| { - Block::new( - vb_b.pp(&i.to_string()), - embed_dim, - num_heads, - &rot_pos_embed, - ) - }) + .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads, &rot_pos_embed)) .collect::<Result<Vec<_>>>()?; Ok(Self { patch_embed, diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index a1f43d35..3681472b 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,9 +1,33 @@ use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; -use std::collections::HashMap; +use std::{collections::HashMap, f32::consts::PI}; -pub const MAX_SEQ_LEN: usize = 4096; +pub const DEFAULT_MAX_SEQ_LEN: usize = 4096; + +#[derive(Debug, Clone, serde::Deserialize, Default)] +pub enum Llama3RopeType { + #[serde(rename = "llama3")] + Llama3, + #[default] + #[serde(rename = "default")] + Default, +} + +#[derive(Debug, Clone, serde::Deserialize, Default)] +pub struct Llama3RopeConfig { + pub factor: f32, + pub low_freq_factor: f32, + pub high_freq_factor: f32, + pub original_max_position_embeddings: usize, + pub rope_type: Llama3RopeType, +} +#[derive(Debug, Clone, serde::Deserialize)] +#[serde(untagged)] +pub enum LlamaEosToks { + Single(u32), + Multiple(Vec<u32>), +} #[derive(Debug, Clone, serde::Deserialize)] pub struct LlamaConfig { @@ -17,7 +41,9 @@ pub struct LlamaConfig { #[serde(default = "default_rope")] pub rope_theta: f32, pub bos_token_id: Option<u32>, - pub eos_token_id: Option<u32>, + pub eos_token_id: Option<LlamaEosToks>, + pub rope_scaling: Option<Llama3RopeConfig>, + pub max_position_embeddings: usize, } impl LlamaConfig { @@ -44,6 +70,8 @@ impl LlamaConfig { use_flash_attn, bos_token_id: self.bos_token_id, eos_token_id: self.eos_token_id, + rope_scaling: self.rope_scaling, + max_position_embeddings: self.max_position_embeddings, } } } @@ -60,7 +88,9 @@ pub struct Config { pub rms_norm_eps: f64, pub rope_theta: f32, pub bos_token_id: Option<u32>, - pub eos_token_id: Option<u32>, + pub eos_token_id: Option<LlamaEosToks>, + pub rope_scaling: Option<Llama3RopeConfig>, + pub max_position_embeddings: usize, } impl Config { @@ -77,6 +107,8 @@ impl Config { rope_theta: 10_000.0, bos_token_id: None, eos_token_id: None, + rope_scaling: None, + max_position_embeddings: DEFAULT_MAX_SEQ_LEN, } } @@ -93,6 +125,8 @@ impl Config { rope_theta: 10_000.0, bos_token_id: None, eos_token_id: None, + rope_scaling: None, + max_position_embeddings: DEFAULT_MAX_SEQ_LEN, } } } @@ -107,18 +141,54 @@ pub struct Cache { device: Device, } +fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> { + let head_dim = cfg.hidden_size / cfg.num_attention_heads; + (0..head_dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32)) + .collect() +} + impl Cache { pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> { // precompute freqs_cis - let n_elem = config.hidden_size / config.num_attention_heads; - let theta: Vec<_> = (0..n_elem) - .step_by(2) - .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), device)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + let theta = match &config.rope_scaling { + None + | Some(Llama3RopeConfig { + rope_type: Llama3RopeType::Default, + .. + }) => calculate_default_inv_freq(config), + Some(rope_scaling) => { + let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32 + / rope_scaling.low_freq_factor; + let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32 + / rope_scaling.high_freq_factor; + + calculate_default_inv_freq(config) + .into_iter() + .map(|freq| { + let wavelen = 2. * PI / freq; + if wavelen < high_freq_wavelen { + freq + } else if wavelen > low_freq_wavelen { + freq / rope_scaling.factor + } else { + let smooth = (rope_scaling.original_max_position_embeddings as f32 + / wavelen + - rope_scaling.low_freq_factor) + / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor); + (1. - smooth) * freq / rope_scaling.factor + smooth * freq + } + }) + .collect::<Vec<_>>() + } + }; + + let theta = Tensor::new(theta, device)?; + + let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)? .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? + .reshape((config.max_position_embeddings, 1))? .matmul(&theta.reshape((1, theta.elem_count()))?)?; // This is different from the paper, see: // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 @@ -160,6 +230,7 @@ struct CausalSelfAttention { use_flash_attn: bool, span: tracing::Span, span_rot: tracing::Span, + max_position_embeddings: usize, } #[cfg(feature = "flash-attn")] @@ -220,15 +291,23 @@ impl CausalSelfAttention { k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; let k_seq_len = k.dims()[1]; - if k_seq_len > MAX_SEQ_LEN { + if k_seq_len > self.max_position_embeddings { k = k - .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .narrow( + D::Minus1, + k_seq_len - self.max_position_embeddings, + self.max_position_embeddings, + )? .contiguous()? } let v_seq_len = v.dims()[1]; - if v_seq_len > 2 * MAX_SEQ_LEN { + if v_seq_len > 2 * self.max_position_embeddings { v = v - .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .narrow( + D::Minus1, + v_seq_len - self.max_position_embeddings, + self.max_position_embeddings, + )? .contiguous()? } } @@ -291,6 +370,7 @@ impl CausalSelfAttention { use_flash_attn: cfg.use_flash_attn, span, span_rot, + max_position_embeddings: cfg.max_position_embeddings, }) } } diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs index d2d47003..5dca6870 100644 --- a/candle-transformers/src/models/llava/config.rs +++ b/candle-transformers/src/models/llava/config.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use crate::models::{ clip::{text_model::Activation, vision_model::ClipVisionConfig}, - llama::Config, + llama::{Config, LlamaEosToks}, }; use serde::{Deserialize, Serialize}; @@ -73,8 +73,10 @@ impl LLaVAConfig { rms_norm_eps: self.rms_norm_eps as f64, rope_theta: self.rope_theta, bos_token_id: Some(self.bos_token_id as u32), - eos_token_id: Some(self.eos_token_id as u32), + eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)), use_flash_attn: false, + rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1 + max_position_embeddings: self.max_position_embeddings, } } } diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs index 4d5a7c47..5cc59e82 100644 --- a/candle-transformers/src/models/stable_diffusion/attention.rs +++ b/candle-transformers/src/models/stable_diffusion/attention.rs @@ -358,7 +358,7 @@ impl SpatialTransformer { let vs_tb = vs.pp("transformer_blocks"); for index in 0..config.depth { let tb = BasicTransformerBlock::new( - vs_tb.pp(&index.to_string()), + vs_tb.pp(index.to_string()), inner_dim, n_heads, d_head, diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs index 20e8ceac..5254818e 100644 --- a/candle-transformers/src/models/stable_diffusion/clip.rs +++ b/candle-transformers/src/models/stable_diffusion/clip.rs @@ -322,7 +322,7 @@ impl ClipEncoder { let vs = vs.pp("layers"); let mut layers: Vec<ClipEncoderLayer> = Vec::new(); for index in 0..c.num_hidden_layers { - let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?; + let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?; layers.push(layer) } Ok(ClipEncoder { layers }) diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs index f23bd425..cbef3316 100644 --- a/candle-transformers/src/models/stable_diffusion/unet_2d.rs +++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs @@ -161,7 +161,7 @@ impl UNet2DConditionModel { transformer_layers_per_block, }; let block = CrossAttnDownBlock2D::new( - vs_db.pp(&i.to_string()), + vs_db.pp(i.to_string()), in_channels, out_channels, Some(time_embed_dim), @@ -171,7 +171,7 @@ impl UNet2DConditionModel { Ok(UNetDownBlock::CrossAttn(block)) } else { let block = DownBlock2D::new( - vs_db.pp(&i.to_string()), + vs_db.pp(i.to_string()), in_channels, out_channels, Some(time_embed_dim), @@ -251,7 +251,7 @@ impl UNet2DConditionModel { transformer_layers_per_block, }; let block = CrossAttnUpBlock2D::new( - vs_ub.pp(&i.to_string()), + vs_ub.pp(i.to_string()), in_channels, prev_out_channels, out_channels, @@ -262,7 +262,7 @@ impl UNet2DConditionModel { Ok(UNetUpBlock::CrossAttn(block)) } else { let block = UpBlock2D::new( - vs_ub.pp(&i.to_string()), + vs_ub.pp(i.to_string()), in_channels, prev_out_channels, out_channels, diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs index 18448427..028c51b7 100644 --- a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs +++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs @@ -146,7 +146,7 @@ impl DownEncoderBlock2D { (0..(config.num_layers)) .map(|i| { let in_channels = if i == 0 { in_channels } else { out_channels }; - ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg) }) .collect::<Result<Vec<_>>>()? }; @@ -235,7 +235,7 @@ impl UpDecoderBlock2D { (0..(config.num_layers)) .map(|i| { let in_channels = if i == 0 { in_channels } else { out_channels }; - ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg) }) .collect::<Result<Vec<_>>>()? }; @@ -328,9 +328,9 @@ impl UNetMidBlock2D { }; let mut attn_resnets = vec![]; for index in 0..config.num_layers { - let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?; + let attn = AttentionBlock::new(vs_attns.pp(index.to_string()), in_channels, attn_cfg)?; let resnet = ResnetBlock2D::new( - vs_resnets.pp(&(index + 1).to_string()), + vs_resnets.pp((index + 1).to_string()), in_channels, resnet_cfg, )?; @@ -425,7 +425,7 @@ impl UNetMidBlock2DCrossAttn { let mut attn_resnets = vec![]; for index in 0..config.num_layers { let attn = SpatialTransformer::new( - vs_attns.pp(&index.to_string()), + vs_attns.pp(index.to_string()), in_channels, n_heads, in_channels / n_heads, @@ -433,7 +433,7 @@ impl UNetMidBlock2DCrossAttn { attn_cfg, )?; let resnet = ResnetBlock2D::new( - vs_resnets.pp(&(index + 1).to_string()), + vs_resnets.pp((index + 1).to_string()), in_channels, resnet_cfg, )?; @@ -515,7 +515,7 @@ impl DownBlock2D { let resnets = (0..config.num_layers) .map(|i| { let in_channels = if i == 0 { in_channels } else { out_channels }; - ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg) }) .collect::<Result<Vec<_>>>()?; let downsampler = if config.add_downsample { @@ -619,7 +619,7 @@ impl CrossAttnDownBlock2D { let attentions = (0..config.downblock.num_layers) .map(|i| { SpatialTransformer::new( - vs_attn.pp(&i.to_string()), + vs_attn.pp(i.to_string()), out_channels, n_heads, out_channels / n_heads, @@ -724,7 +724,7 @@ impl UpBlock2D { out_channels }; let in_channels = resnet_in_channels + res_skip_channels; - ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg) }) .collect::<Result<Vec<_>>>()?; let upsampler = if config.add_upsample { @@ -826,7 +826,7 @@ impl CrossAttnUpBlock2D { let attentions = (0..config.upblock.num_layers) .map(|i| { SpatialTransformer::new( - vs_attn.pp(&i.to_string()), + vs_attn.pp(i.to_string()), out_channels, n_heads, out_channels / n_heads, diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs index 21709afe..670b3f56 100644 --- a/candle-transformers/src/models/stable_diffusion/vae.rs +++ b/candle-transformers/src/models/stable_diffusion/vae.rs @@ -80,7 +80,7 @@ impl Encoder { ..Default::default() }; let down_block = DownEncoderBlock2D::new( - vs_down_blocks.pp(&index.to_string()), + vs_down_blocks.pp(index.to_string()), in_channels, out_channels, cfg, @@ -222,7 +222,7 @@ impl Decoder { ..Default::default() }; let up_block = UpDecoderBlock2D::new( - vs_up_blocks.pp(&index.to_string()), + vs_up_blocks.pp(index.to_string()), in_channels, out_channels, cfg, diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 8a7a8955..21517d64 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -601,7 +601,7 @@ impl T5Block { None }; let ff_i = if cross_attn.is_some() { 2 } else { 1 }; - let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?; + let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?; Ok(Self { self_attn, cross_attn, |