summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorEric Buehler <65165915+EricLBuehler@users.noreply.github.com>2024-07-26 15:32:26 -0400
committerGitHub <noreply@github.com>2024-07-26 21:32:26 +0200
commit0f5cbb08b36a2d962470ec590a2d2bd9770bd12d (patch)
treea5d6911051646e96fc833664b44c530c76fe4416 /candle-transformers
parentddafc61055601002622778b7762c15bd60057c1f (diff)
downloadcandle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.tar.gz
candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.tar.bz2
candle-0f5cbb08b36a2d962470ec590a2d2bd9770bd12d.zip
Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope * Clippy * Format * Clippy * Add support for multiple eos tokens: * Untagged either * Remove either dep and fix settings.json * Make the max positional embeddings configurable
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/beit.rs2
-rw-r--r--candle-transformers/src/models/clip/text_model.rs2
-rw-r--r--candle-transformers/src/models/dinov2.rs2
-rw-r--r--candle-transformers/src/models/dinov2reg4.rs2
-rw-r--r--candle-transformers/src/models/encodec.rs2
-rw-r--r--candle-transformers/src/models/eva2.rs9
-rw-r--r--candle-transformers/src/models/llama.rs112
-rw-r--r--candle-transformers/src/models/llava/config.rs6
-rw-r--r--candle-transformers/src/models/stable_diffusion/attention.rs2
-rw-r--r--candle-transformers/src/models/stable_diffusion/clip.rs2
-rw-r--r--candle-transformers/src/models/stable_diffusion/unet_2d.rs8
-rw-r--r--candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs20
-rw-r--r--candle-transformers/src/models/stable_diffusion/vae.rs4
-rw-r--r--candle-transformers/src/models/t5.rs2
14 files changed, 125 insertions, 50 deletions
diff --git a/candle-transformers/src/models/beit.rs b/candle-transformers/src/models/beit.rs
index 62bdd75a..8f6284a8 100644
--- a/candle-transformers/src/models/beit.rs
+++ b/candle-transformers/src/models/beit.rs
@@ -288,7 +288,7 @@ impl BeitVisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs
index 4e4b4c90..51db14ee 100644
--- a/candle-transformers/src/models/clip/text_model.rs
+++ b/candle-transformers/src/models/clip/text_model.rs
@@ -249,7 +249,7 @@ impl ClipEncoder {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers() {
- let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
index 00e501ce..706dfda0 100644
--- a/candle-transformers/src/models/dinov2.rs
+++ b/candle-transformers/src/models/dinov2.rs
@@ -214,7 +214,7 @@ impl DinoVisionTransformer {
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs
index 6bbe2e24..1d81703c 100644
--- a/candle-transformers/src/models/dinov2reg4.rs
+++ b/candle-transformers/src/models/dinov2reg4.rs
@@ -212,7 +212,7 @@ impl DinoVisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs
index 14a85d3e..fb70fb52 100644
--- a/candle-transformers/src/models/encodec.rs
+++ b/candle-transformers/src/models/encodec.rs
@@ -571,7 +571,7 @@ impl<'a> Layer<'a> {
}
fn next(&mut self) -> VarBuilder {
- let vb = self.vb.pp(&self.cnt.to_string());
+ let vb = self.vb.pp(self.cnt.to_string());
self.cnt += 1;
vb
}
diff --git a/candle-transformers/src/models/eva2.rs b/candle-transformers/src/models/eva2.rs
index eb2df4cd..013c385d 100644
--- a/candle-transformers/src/models/eva2.rs
+++ b/candle-transformers/src/models/eva2.rs
@@ -255,14 +255,7 @@ impl EVA2VisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| {
- Block::new(
- vb_b.pp(&i.to_string()),
- embed_dim,
- num_heads,
- &rot_pos_embed,
- )
- })
+ .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads, &rot_pos_embed))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index a1f43d35..3681472b 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,9 +1,33 @@
use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use std::collections::HashMap;
+use std::{collections::HashMap, f32::consts::PI};
-pub const MAX_SEQ_LEN: usize = 4096;
+pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub enum Llama3RopeType {
+ #[serde(rename = "llama3")]
+ Llama3,
+ #[default]
+ #[serde(rename = "default")]
+ Default,
+}
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub struct Llama3RopeConfig {
+ pub factor: f32,
+ pub low_freq_factor: f32,
+ pub high_freq_factor: f32,
+ pub original_max_position_embeddings: usize,
+ pub rope_type: Llama3RopeType,
+}
+#[derive(Debug, Clone, serde::Deserialize)]
+#[serde(untagged)]
+pub enum LlamaEosToks {
+ Single(u32),
+ Multiple(Vec<u32>),
+}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
@@ -17,7 +41,9 @@ pub struct LlamaConfig {
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
- pub eos_token_id: Option<u32>,
+ pub eos_token_id: Option<LlamaEosToks>,
+ pub rope_scaling: Option<Llama3RopeConfig>,
+ pub max_position_embeddings: usize,
}
impl LlamaConfig {
@@ -44,6 +70,8 @@ impl LlamaConfig {
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
+ rope_scaling: self.rope_scaling,
+ max_position_embeddings: self.max_position_embeddings,
}
}
}
@@ -60,7 +88,9 @@ pub struct Config {
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
- pub eos_token_id: Option<u32>,
+ pub eos_token_id: Option<LlamaEosToks>,
+ pub rope_scaling: Option<Llama3RopeConfig>,
+ pub max_position_embeddings: usize,
}
impl Config {
@@ -77,6 +107,8 @@ impl Config {
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
+ rope_scaling: None,
+ max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
}
}
@@ -93,6 +125,8 @@ impl Config {
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
+ rope_scaling: None,
+ max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
}
}
}
@@ -107,18 +141,54 @@ pub struct Cache {
device: Device,
}
+fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
+ let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+ (0..head_dim)
+ .step_by(2)
+ .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
+ .collect()
+}
+
impl Cache {
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
- let n_elem = config.hidden_size / config.num_attention_heads;
- let theta: Vec<_> = (0..n_elem)
- .step_by(2)
- .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
- .collect();
- let theta = Tensor::new(theta.as_slice(), device)?;
- let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+ let theta = match &config.rope_scaling {
+ None
+ | Some(Llama3RopeConfig {
+ rope_type: Llama3RopeType::Default,
+ ..
+ }) => calculate_default_inv_freq(config),
+ Some(rope_scaling) => {
+ let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+ / rope_scaling.low_freq_factor;
+ let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+ / rope_scaling.high_freq_factor;
+
+ calculate_default_inv_freq(config)
+ .into_iter()
+ .map(|freq| {
+ let wavelen = 2. * PI / freq;
+ if wavelen < high_freq_wavelen {
+ freq
+ } else if wavelen > low_freq_wavelen {
+ freq / rope_scaling.factor
+ } else {
+ let smooth = (rope_scaling.original_max_position_embeddings as f32
+ / wavelen
+ - rope_scaling.low_freq_factor)
+ / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
+ (1. - smooth) * freq / rope_scaling.factor + smooth * freq
+ }
+ })
+ .collect::<Vec<_>>()
+ }
+ };
+
+ let theta = Tensor::new(theta, device)?;
+
+ let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
.to_dtype(DType::F32)?
- .reshape((MAX_SEQ_LEN, 1))?
+ .reshape((config.max_position_embeddings, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
@@ -160,6 +230,7 @@ struct CausalSelfAttention {
use_flash_attn: bool,
span: tracing::Span,
span_rot: tracing::Span,
+ max_position_embeddings: usize,
}
#[cfg(feature = "flash-attn")]
@@ -220,15 +291,23 @@ impl CausalSelfAttention {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1];
- if k_seq_len > MAX_SEQ_LEN {
+ if k_seq_len > self.max_position_embeddings {
k = k
- .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+ .narrow(
+ D::Minus1,
+ k_seq_len - self.max_position_embeddings,
+ self.max_position_embeddings,
+ )?
.contiguous()?
}
let v_seq_len = v.dims()[1];
- if v_seq_len > 2 * MAX_SEQ_LEN {
+ if v_seq_len > 2 * self.max_position_embeddings {
v = v
- .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+ .narrow(
+ D::Minus1,
+ v_seq_len - self.max_position_embeddings,
+ self.max_position_embeddings,
+ )?
.contiguous()?
}
}
@@ -291,6 +370,7 @@ impl CausalSelfAttention {
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
+ max_position_embeddings: cfg.max_position_embeddings,
})
}
}
diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs
index d2d47003..5dca6870 100644
--- a/candle-transformers/src/models/llava/config.rs
+++ b/candle-transformers/src/models/llava/config.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
use crate::models::{
clip::{text_model::Activation, vision_model::ClipVisionConfig},
- llama::Config,
+ llama::{Config, LlamaEosToks},
};
use serde::{Deserialize, Serialize};
@@ -73,8 +73,10 @@ impl LLaVAConfig {
rms_norm_eps: self.rms_norm_eps as f64,
rope_theta: self.rope_theta,
bos_token_id: Some(self.bos_token_id as u32),
- eos_token_id: Some(self.eos_token_id as u32),
+ eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
use_flash_attn: false,
+ rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
+ max_position_embeddings: self.max_position_embeddings,
}
}
}
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index 4d5a7c47..5cc59e82 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -358,7 +358,7 @@ impl SpatialTransformer {
let vs_tb = vs.pp("transformer_blocks");
for index in 0..config.depth {
let tb = BasicTransformerBlock::new(
- vs_tb.pp(&index.to_string()),
+ vs_tb.pp(index.to_string()),
inner_dim,
n_heads,
d_head,
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 20e8ceac..5254818e 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -322,7 +322,7 @@ impl ClipEncoder {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers {
- let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
index f23bd425..cbef3316 100644
--- a/candle-transformers/src/models/stable_diffusion/unet_2d.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
@@ -161,7 +161,7 @@ impl UNet2DConditionModel {
transformer_layers_per_block,
};
let block = CrossAttnDownBlock2D::new(
- vs_db.pp(&i.to_string()),
+ vs_db.pp(i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
@@ -171,7 +171,7 @@ impl UNet2DConditionModel {
Ok(UNetDownBlock::CrossAttn(block))
} else {
let block = DownBlock2D::new(
- vs_db.pp(&i.to_string()),
+ vs_db.pp(i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
@@ -251,7 +251,7 @@ impl UNet2DConditionModel {
transformer_layers_per_block,
};
let block = CrossAttnUpBlock2D::new(
- vs_ub.pp(&i.to_string()),
+ vs_ub.pp(i.to_string()),
in_channels,
prev_out_channels,
out_channels,
@@ -262,7 +262,7 @@ impl UNet2DConditionModel {
Ok(UNetUpBlock::CrossAttn(block))
} else {
let block = UpBlock2D::new(
- vs_ub.pp(&i.to_string()),
+ vs_ub.pp(i.to_string()),
in_channels,
prev_out_channels,
out_channels,
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
index 18448427..028c51b7 100644
--- a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
@@ -146,7 +146,7 @@ impl DownEncoderBlock2D {
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
- ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
@@ -235,7 +235,7 @@ impl UpDecoderBlock2D {
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
- ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
@@ -328,9 +328,9 @@ impl UNetMidBlock2D {
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
- let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+ let attn = AttentionBlock::new(vs_attns.pp(index.to_string()), in_channels, attn_cfg)?;
let resnet = ResnetBlock2D::new(
- vs_resnets.pp(&(index + 1).to_string()),
+ vs_resnets.pp((index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
@@ -425,7 +425,7 @@ impl UNetMidBlock2DCrossAttn {
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = SpatialTransformer::new(
- vs_attns.pp(&index.to_string()),
+ vs_attns.pp(index.to_string()),
in_channels,
n_heads,
in_channels / n_heads,
@@ -433,7 +433,7 @@ impl UNetMidBlock2DCrossAttn {
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
- vs_resnets.pp(&(index + 1).to_string()),
+ vs_resnets.pp((index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
@@ -515,7 +515,7 @@ impl DownBlock2D {
let resnets = (0..config.num_layers)
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
- ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let downsampler = if config.add_downsample {
@@ -619,7 +619,7 @@ impl CrossAttnDownBlock2D {
let attentions = (0..config.downblock.num_layers)
.map(|i| {
SpatialTransformer::new(
- vs_attn.pp(&i.to_string()),
+ vs_attn.pp(i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
@@ -724,7 +724,7 @@ impl UpBlock2D {
out_channels
};
let in_channels = resnet_in_channels + res_skip_channels;
- ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let upsampler = if config.add_upsample {
@@ -826,7 +826,7 @@ impl CrossAttnUpBlock2D {
let attentions = (0..config.upblock.num_layers)
.map(|i| {
SpatialTransformer::new(
- vs_attn.pp(&i.to_string()),
+ vs_attn.pp(i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs
index 21709afe..670b3f56 100644
--- a/candle-transformers/src/models/stable_diffusion/vae.rs
+++ b/candle-transformers/src/models/stable_diffusion/vae.rs
@@ -80,7 +80,7 @@ impl Encoder {
..Default::default()
};
let down_block = DownEncoderBlock2D::new(
- vs_down_blocks.pp(&index.to_string()),
+ vs_down_blocks.pp(index.to_string()),
in_channels,
out_channels,
cfg,
@@ -222,7 +222,7 @@ impl Decoder {
..Default::default()
};
let up_block = UpDecoderBlock2D::new(
- vs_up_blocks.pp(&index.to_string()),
+ vs_up_blocks.pp(index.to_string()),
in_channels,
out_channels,
cfg,
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 8a7a8955..21517d64 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -601,7 +601,7 @@ impl T5Block {
None
};
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
- let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
+ let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
Ok(Self {
self_attn,
cross_attn,