-rw-r--r--  candle-core/benches/benchmarks/affine.rs | 2
-rw-r--r--  candle-core/benches/benchmarks/qmatmul.rs | 4
-rw-r--r--  candle-core/benches/benchmarks/unary.rs | 2
-rw-r--r--  candle-core/benches/benchmarks/where_cond.rs | 6
-rw-r--r--  candle-core/src/tensor.rs | 6
-rw-r--r--  candle-examples/Cargo.toml | 2
-rw-r--r--  candle-examples/examples/llama/main.rs | 30
-rw-r--r--  candle-examples/examples/yolo-v3/darknet.rs | 2
-rw-r--r--  candle-nn/src/activation.rs | 6
-rw-r--r--  candle-nn/src/var_builder.rs | 1
-rw-r--r--  candle-transformers/src/models/beit.rs | 2
-rw-r--r--  candle-transformers/src/models/clip/text_model.rs | 2
-rw-r--r--  candle-transformers/src/models/dinov2.rs | 2
-rw-r--r--  candle-transformers/src/models/dinov2reg4.rs | 2
-rw-r--r--  candle-transformers/src/models/encodec.rs | 2
-rw-r--r--  candle-transformers/src/models/eva2.rs | 9
-rw-r--r--  candle-transformers/src/models/llama.rs | 112
-rw-r--r--  candle-transformers/src/models/llava/config.rs | 6
-rw-r--r--  candle-transformers/src/models/stable_diffusion/attention.rs | 2
-rw-r--r--  candle-transformers/src/models/stable_diffusion/clip.rs | 2
-rw-r--r--  candle-transformers/src/models/stable_diffusion/unet_2d.rs | 8
-rw-r--r--  candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs | 20
-rw-r--r--  candle-transformers/src/models/stable_diffusion/vae.rs | 4
-rw-r--r--  candle-transformers/src/models/t5.rs | 2
24 files changed, 165 insertions, 71 deletions
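The main change in this patch is Llama 3.1 support in candle-transformers/src/models/llama.rs: the config gains an optional rope_scaling field and Cache::new applies the llama3 frequency scaling when it is present. Below is a minimal standalone sketch of that scaling step, not part of the patch itself; the struct mirrors the Llama3RopeConfig fields added in the diff, and the function reproduces the frequency adjustment shown further down.

use std::f32::consts::PI;

// Sketch only: mirrors the Llama3RopeConfig fields introduced by this patch.
struct Llama3RopeConfig {
    factor: f32,
    low_freq_factor: f32,
    high_freq_factor: f32,
    original_max_position_embeddings: usize,
}

// Rescale the default RoPE inverse frequencies the way the new Cache::new does:
// high-frequency components are kept as-is, low-frequency ones are divided by
// `factor`, and the band in between is interpolated smoothly.
fn llama3_scale_inv_freqs(inv_freqs: &[f32], cfg: &Llama3RopeConfig) -> Vec<f32> {
    let low_freq_wavelen = cfg.original_max_position_embeddings as f32 / cfg.low_freq_factor;
    let high_freq_wavelen = cfg.original_max_position_embeddings as f32 / cfg.high_freq_factor;
    inv_freqs
        .iter()
        .map(|&freq| {
            let wavelen = 2. * PI / freq;
            if wavelen < high_freq_wavelen {
                freq
            } else if wavelen > low_freq_wavelen {
                freq / cfg.factor
            } else {
                let smooth = (cfg.original_max_position_embeddings as f32 / wavelen
                    - cfg.low_freq_factor)
                    / (cfg.high_freq_factor - cfg.low_freq_factor);
                (1. - smooth) * freq / cfg.factor + smooth * freq
            }
        })
        .collect()
}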
diff --git a/candle-core/benches/benchmarks/affine.rs b/candle-core/benches/benchmarks/affine.rs
index eded9f57..c1004c6c 100644
--- a/candle-core/benches/benchmarks/affine.rs
+++ b/candle-core/benches/benchmarks/affine.rs
@@ -12,7 +12,7 @@ fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
let m = 1024;
let k = 1024;
- let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+ let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
let flops = b * m * k * dtype.size_in_bytes();
diff --git a/candle-core/benches/benchmarks/qmatmul.rs b/candle-core/benches/benchmarks/qmatmul.rs
index ccb136ac..4d34588b 100644
--- a/candle-core/benches/benchmarks/qmatmul.rs
+++ b/candle-core/benches/benchmarks/qmatmul.rs
@@ -7,7 +7,7 @@ use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(matmul: &QMatMul, x: &Tensor) {
- matmul.forward(&x).unwrap();
+ matmul.forward(x).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
@@ -50,7 +50,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
- for dtype in vec![
+ for dtype in [
GgmlDType::F32,
GgmlDType::F16,
GgmlDType::Q4_0,
diff --git a/candle-core/benches/benchmarks/unary.rs b/candle-core/benches/benchmarks/unary.rs
index a8e0d025..9efd7509 100644
--- a/candle-core/benches/benchmarks/unary.rs
+++ b/candle-core/benches/benchmarks/unary.rs
@@ -12,7 +12,7 @@ fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &
let m = 1024;
let k = 1024;
- let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
+ let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
.unwrap()
.to_dtype(dtype)
.unwrap()
diff --git a/candle-core/benches/benchmarks/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs
index c517dcf5..0e91f656 100644
--- a/candle-core/benches/benchmarks/where_cond.rs
+++ b/candle-core/benches/benchmarks/where_cond.rs
@@ -25,9 +25,9 @@ const SIZE: usize = B * M * K;
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
- let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
- let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
- let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
+ let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
+ let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
+ let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
let elements = B * M * K;
// E.g. 2 f32 tensors + 1 u8 tensor
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index dd1b44b0..82532f20 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -590,9 +590,9 @@ impl Tensor {
///
/// * `args` - A slice of 1D tensors.
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
- /// first dimension corresponds to the cardinality of the second input and the second
- /// dimension corresponds to the cardinality of the first input. If ij is selected, the
- /// dimensions are in the same order as the cardinality of the inputs.
+ /// first dimension corresponds to the cardinality of the second input and the second
+ /// dimension corresponds to the cardinality of the first input. If ij is selected, the
+ /// dimensions are in the same order as the cardinality of the inputs.
///
/// # Examples
///
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index fa5c620a..56e3d535 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -35,7 +35,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
-cpal= { version = "0.15.2", optional = true }
+cpal = { version = "0.15.2", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index fa7ce81b..93f1e508 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -32,7 +32,9 @@ enum Which {
V1,
V2,
V3,
+ V31,
V3Instruct,
+ V31Instruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@@ -133,6 +135,8 @@ fn main() -> Result<()> {
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+ Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+ Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@@ -146,7 +150,13 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
- Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
+ Which::V1
+ | Which::V2
+ | Which::V3
+ | Which::V3Instruct
+ | Which::V31
+ | Which::V31Instruct
+ | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -157,9 +167,11 @@ fn main() -> Result<()> {
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
- let eos_token_id = config
- .eos_token_id
- .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
+ let eos_token_id = config.eos_token_id.or_else(|| {
+ tokenizer
+ .token_to_id(EOS_TOKEN)
+ .map(model::LlamaEosToks::Single)
+ });
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@@ -217,8 +229,14 @@ fn main() -> Result<()> {
token_generated += 1;
tokens.push(next_token);
- if Some(next_token) == eos_token_id {
- break;
+ match eos_token_id {
+ Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+ break;
+ }
+ Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+ break;
+ }
+ _ => (),
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs
index 331e712c..1892acdd 100644
--- a/candle-examples/examples/yolo-v3/darknet.rs
+++ b/candle-examples/examples/yolo-v3/darknet.rs
@@ -272,7 +272,7 @@ impl Darknet {
let mut prev_channels: usize = 3;
for (index, block) in self.blocks.iter().enumerate() {
let channels_and_bl = match block.block_type.as_str() {
- "convolutional" => conv(vb.pp(&index.to_string()), index, prev_channels, block)?,
+ "convolutional" => conv(vb.pp(index.to_string()), index, prev_channels, block)?,
"upsample" => upsample(prev_channels)?,
"shortcut" => shortcut(index, prev_channels, block)?,
"route" => route(index, &blocks, block)?,
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index b9745375..fc1819f5 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -93,9 +93,9 @@ impl candle::Module for PReLU {
/// # Arguments
///
/// * `num_channels` - The number of channels. Use `None` to have as single trainable value and
-/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
-/// function, the input tensor shape `s` should either be one dimension with this number of
-/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
+/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
+/// function, the input tensor shape `s` should either be one dimension with this number of
+/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
pub fn prelu(num_channels: Option<usize>, vs: crate::VarBuilder) -> Result<PReLU> {
let init_ws = crate::init::Init::Const(0.25);
// When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1.
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index d6f6214f..f6e6160b 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -264,6 +264,7 @@ impl SimpleBackend for VarMap {
}
}
+#[allow(dead_code)]
pub struct SafeTensorWithRouting<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
diff --git a/candle-transformers/src/models/beit.rs b/candle-transformers/src/models/beit.rs
index 62bdd75a..8f6284a8 100644
--- a/candle-transformers/src/models/beit.rs
+++ b/candle-transformers/src/models/beit.rs
@@ -288,7 +288,7 @@ impl BeitVisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs
index 4e4b4c90..51db14ee 100644
--- a/candle-transformers/src/models/clip/text_model.rs
+++ b/candle-transformers/src/models/clip/text_model.rs
@@ -249,7 +249,7 @@ impl ClipEncoder {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers() {
- let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
index 00e501ce..706dfda0 100644
--- a/candle-transformers/src/models/dinov2.rs
+++ b/candle-transformers/src/models/dinov2.rs
@@ -214,7 +214,7 @@ impl DinoVisionTransformer {
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs
index 6bbe2e24..1d81703c 100644
--- a/candle-transformers/src/models/dinov2reg4.rs
+++ b/candle-transformers/src/models/dinov2reg4.rs
@@ -212,7 +212,7 @@ impl DinoVisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs
index 14a85d3e..fb70fb52 100644
--- a/candle-transformers/src/models/encodec.rs
+++ b/candle-transformers/src/models/encodec.rs
@@ -571,7 +571,7 @@ impl<'a> Layer<'a> {
}
fn next(&mut self) -> VarBuilder {
- let vb = self.vb.pp(&self.cnt.to_string());
+ let vb = self.vb.pp(self.cnt.to_string());
self.cnt += 1;
vb
}
diff --git a/candle-transformers/src/models/eva2.rs b/candle-transformers/src/models/eva2.rs
index eb2df4cd..013c385d 100644
--- a/candle-transformers/src/models/eva2.rs
+++ b/candle-transformers/src/models/eva2.rs
@@ -255,14 +255,7 @@ impl EVA2VisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
- .map(|i| {
- Block::new(
- vb_b.pp(&i.to_string()),
- embed_dim,
- num_heads,
- &rot_pos_embed,
- )
- })
+ .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads, &rot_pos_embed))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index a1f43d35..3681472b 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,9 +1,33 @@
use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use std::collections::HashMap;
+use std::{collections::HashMap, f32::consts::PI};
-pub const MAX_SEQ_LEN: usize = 4096;
+pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub enum Llama3RopeType {
+ #[serde(rename = "llama3")]
+ Llama3,
+ #[default]
+ #[serde(rename = "default")]
+ Default,
+}
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub struct Llama3RopeConfig {
+ pub factor: f32,
+ pub low_freq_factor: f32,
+ pub high_freq_factor: f32,
+ pub original_max_position_embeddings: usize,
+ pub rope_type: Llama3RopeType,
+}
+#[derive(Debug, Clone, serde::Deserialize)]
+#[serde(untagged)]
+pub enum LlamaEosToks {
+ Single(u32),
+ Multiple(Vec<u32>),
+}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
@@ -17,7 +41,9 @@ pub struct LlamaConfig {
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
- pub eos_token_id: Option<u32>,
+ pub eos_token_id: Option<LlamaEosToks>,
+ pub rope_scaling: Option<Llama3RopeConfig>,
+ pub max_position_embeddings: usize,
}
impl LlamaConfig {
@@ -44,6 +70,8 @@ impl LlamaConfig {
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
+ rope_scaling: self.rope_scaling,
+ max_position_embeddings: self.max_position_embeddings,
}
}
}
@@ -60,7 +88,9 @@ pub struct Config {
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
- pub eos_token_id: Option<u32>,
+ pub eos_token_id: Option<LlamaEosToks>,
+ pub rope_scaling: Option<Llama3RopeConfig>,
+ pub max_position_embeddings: usize,
}
impl Config {
@@ -77,6 +107,8 @@ impl Config {
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
+ rope_scaling: None,
+ max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
}
}
@@ -93,6 +125,8 @@ impl Config {
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
+ rope_scaling: None,
+ max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
}
}
}
@@ -107,18 +141,54 @@ pub struct Cache {
device: Device,
}
+fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
+ let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+ (0..head_dim)
+ .step_by(2)
+ .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
+ .collect()
+}
+
impl Cache {
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
- let n_elem = config.hidden_size / config.num_attention_heads;
- let theta: Vec<_> = (0..n_elem)
- .step_by(2)
- .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
- .collect();
- let theta = Tensor::new(theta.as_slice(), device)?;
- let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+ let theta = match &config.rope_scaling {
+ None
+ | Some(Llama3RopeConfig {
+ rope_type: Llama3RopeType::Default,
+ ..
+ }) => calculate_default_inv_freq(config),
+ Some(rope_scaling) => {
+ let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+ / rope_scaling.low_freq_factor;
+ let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+ / rope_scaling.high_freq_factor;
+
+ calculate_default_inv_freq(config)
+ .into_iter()
+ .map(|freq| {
+ let wavelen = 2. * PI / freq;
+ if wavelen < high_freq_wavelen {
+ freq
+ } else if wavelen > low_freq_wavelen {
+ freq / rope_scaling.factor
+ } else {
+ let smooth = (rope_scaling.original_max_position_embeddings as f32
+ / wavelen
+ - rope_scaling.low_freq_factor)
+ / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
+ (1. - smooth) * freq / rope_scaling.factor + smooth * freq
+ }
+ })
+ .collect::<Vec<_>>()
+ }
+ };
+
+ let theta = Tensor::new(theta, device)?;
+
+ let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
.to_dtype(DType::F32)?
- .reshape((MAX_SEQ_LEN, 1))?
+ .reshape((config.max_position_embeddings, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
@@ -160,6 +230,7 @@ struct CausalSelfAttention {
use_flash_attn: bool,
span: tracing::Span,
span_rot: tracing::Span,
+ max_position_embeddings: usize,
}
#[cfg(feature = "flash-attn")]
@@ -220,15 +291,23 @@ impl CausalSelfAttention {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1];
- if k_seq_len > MAX_SEQ_LEN {
+ if k_seq_len > self.max_position_embeddings {
k = k
- .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+ .narrow(
+ D::Minus1,
+ k_seq_len - self.max_position_embeddings,
+ self.max_position_embeddings,
+ )?
.contiguous()?
}
let v_seq_len = v.dims()[1];
- if v_seq_len > 2 * MAX_SEQ_LEN {
+ if v_seq_len > 2 * self.max_position_embeddings {
v = v
- .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+ .narrow(
+ D::Minus1,
+ v_seq_len - self.max_position_embeddings,
+ self.max_position_embeddings,
+ )?
.contiguous()?
}
}
@@ -291,6 +370,7 @@ impl CausalSelfAttention {
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
+ max_position_embeddings: cfg.max_position_embeddings,
})
}
}
diff --git a/candle-transformers/src/models/llava/config.rs b/candle-transformers/src/models/llava/config.rs
index d2d47003..5dca6870 100644
--- a/candle-transformers/src/models/llava/config.rs
+++ b/candle-transformers/src/models/llava/config.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
use crate::models::{
clip::{text_model::Activation, vision_model::ClipVisionConfig},
- llama::Config,
+ llama::{Config, LlamaEosToks},
};
use serde::{Deserialize, Serialize};
@@ -73,8 +73,10 @@ impl LLaVAConfig {
rms_norm_eps: self.rms_norm_eps as f64,
rope_theta: self.rope_theta,
bos_token_id: Some(self.bos_token_id as u32),
- eos_token_id: Some(self.eos_token_id as u32),
+ eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
use_flash_attn: false,
+ rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
+ max_position_embeddings: self.max_position_embeddings,
}
}
}
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index 4d5a7c47..5cc59e82 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -358,7 +358,7 @@ impl SpatialTransformer {
let vs_tb = vs.pp("transformer_blocks");
for index in 0..config.depth {
let tb = BasicTransformerBlock::new(
- vs_tb.pp(&index.to_string()),
+ vs_tb.pp(index.to_string()),
inner_dim,
n_heads,
d_head,
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 20e8ceac..5254818e 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -322,7 +322,7 @@ impl ClipEncoder {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers {
- let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
index f23bd425..cbef3316 100644
--- a/candle-transformers/src/models/stable_diffusion/unet_2d.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
@@ -161,7 +161,7 @@ impl UNet2DConditionModel {
transformer_layers_per_block,
};
let block = CrossAttnDownBlock2D::new(
- vs_db.pp(&i.to_string()),
+ vs_db.pp(i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
@@ -171,7 +171,7 @@ impl UNet2DConditionModel {
Ok(UNetDownBlock::CrossAttn(block))
} else {
let block = DownBlock2D::new(
- vs_db.pp(&i.to_string()),
+ vs_db.pp(i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
@@ -251,7 +251,7 @@ impl UNet2DConditionModel {
transformer_layers_per_block,
};
let block = CrossAttnUpBlock2D::new(
- vs_ub.pp(&i.to_string()),
+ vs_ub.pp(i.to_string()),
in_channels,
prev_out_channels,
out_channels,
@@ -262,7 +262,7 @@ impl UNet2DConditionModel {
Ok(UNetUpBlock::CrossAttn(block))
} else {
let block = UpBlock2D::new(
- vs_ub.pp(&i.to_string()),
+ vs_ub.pp(i.to_string()),
in_channels,
prev_out_channels,
out_channels,
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
index 18448427..028c51b7 100644
--- a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
@@ -146,7 +146,7 @@ impl DownEncoderBlock2D {
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
- ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
@@ -235,7 +235,7 @@ impl UpDecoderBlock2D {
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
- ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
@@ -328,9 +328,9 @@ impl UNetMidBlock2D {
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
- let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+ let attn = AttentionBlock::new(vs_attns.pp(index.to_string()), in_channels, attn_cfg)?;
let resnet = ResnetBlock2D::new(
- vs_resnets.pp(&(index + 1).to_string()),
+ vs_resnets.pp((index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
@@ -425,7 +425,7 @@ impl UNetMidBlock2DCrossAttn {
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = SpatialTransformer::new(
- vs_attns.pp(&index.to_string()),
+ vs_attns.pp(index.to_string()),
in_channels,
n_heads,
in_channels / n_heads,
@@ -433,7 +433,7 @@ impl UNetMidBlock2DCrossAttn {
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
- vs_resnets.pp(&(index + 1).to_string()),
+ vs_resnets.pp((index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
@@ -515,7 +515,7 @@ impl DownBlock2D {
let resnets = (0..config.num_layers)
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
- ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let downsampler = if config.add_downsample {
@@ -619,7 +619,7 @@ impl CrossAttnDownBlock2D {
let attentions = (0..config.downblock.num_layers)
.map(|i| {
SpatialTransformer::new(
- vs_attn.pp(&i.to_string()),
+ vs_attn.pp(i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
@@ -724,7 +724,7 @@ impl UpBlock2D {
out_channels
};
let in_channels = resnet_in_channels + res_skip_channels;
- ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let upsampler = if config.add_upsample {
@@ -826,7 +826,7 @@ impl CrossAttnUpBlock2D {
let attentions = (0..config.upblock.num_layers)
.map(|i| {
SpatialTransformer::new(
- vs_attn.pp(&i.to_string()),
+ vs_attn.pp(i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs
index 21709afe..670b3f56 100644
--- a/candle-transformers/src/models/stable_diffusion/vae.rs
+++ b/candle-transformers/src/models/stable_diffusion/vae.rs
@@ -80,7 +80,7 @@ impl Encoder {
..Default::default()
};
let down_block = DownEncoderBlock2D::new(
- vs_down_blocks.pp(&index.to_string()),
+ vs_down_blocks.pp(index.to_string()),
in_channels,
out_channels,
cfg,
@@ -222,7 +222,7 @@ impl Decoder {
..Default::default()
};
let up_block = UpDecoderBlock2D::new(
- vs_up_blocks.pp(&index.to_string()),
+ vs_up_blocks.pp(index.to_string()),
in_channels,
out_channels,
cfg,
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 8a7a8955..21517d64 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -601,7 +601,7 @@ impl T5Block {
None
};
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
- let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
+ let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
Ok(Self {
self_attn,
cross_attn,