Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/bert.rs           |   7
-rw-r--r--  candle-transformers/src/models/bigcode.rs        |   7
-rw-r--r--  candle-transformers/src/models/falcon.rs         |   7
-rw-r--r--  candle-transformers/src/models/llama.rs          |   9
-rw-r--r--  candle-transformers/src/models/mod.rs            |   1
-rw-r--r--  candle-transformers/src/models/repvgg.rs         | 306
-rw-r--r--  candle-transformers/src/models/whisper/model.rs  |   7
7 files changed, 313 insertions, 31 deletions
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 51c524f5..810f2803 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -1,6 +1,6 @@
 use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
@@ -112,11 +112,6 @@ impl Config {
     }
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs
index c4a2d1db..e69f08c8 100644
--- a/candle-transformers/src/models/bigcode.rs
+++ b/candle-transformers/src/models/bigcode.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
     let weight = vb.get((size2, size1), "weight")?;
@@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
     Ok(Linear::new(weight, bias))
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
     let weight = vb.get(size, "weight")?;
     let bias = vb.get(size, "bias")?;
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 6ede136a..ef5a92fc 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 const MAX_SEQ_LEN: usize = 5000;
 
@@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
     Ok(LayerNorm::new(weight, bias, eps))
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
 #[derive(Debug)]
 pub struct Config {
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7e8c8920..f003866a 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,6 +1,6 @@
 use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -136,11 +136,6 @@ impl Cache {
     }
 }
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, cfg.hidden_size))
-}
-
 struct RmsNorm {
     inner: candle_nn::RmsNorm,
     span: tracing::Span,
 }
@@ -409,7 +404,7 @@ impl Llama {
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 94a3bd5b..a60b5a06 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -26,6 +26,7 @@ pub mod quantized_mixformer;
 pub mod quantized_mpt;
 pub mod quantized_stable_lm;
 pub mod quantized_t5;
+pub mod repvgg;
 pub mod resnet;
 pub mod segment_anything;
 pub mod stable_diffusion;
diff --git a/candle-transformers/src/models/repvgg.rs b/candle-transformers/src/models/repvgg.rs
new file mode 100644
index 00000000..34016e5b
--- /dev/null
+++ b/candle-transformers/src/models/repvgg.rs
@@ -0,0 +1,306 @@
+//! RepVGG inference implementation
+//!
+//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021
+//! https://arxiv.org/abs/2101.03697
+
+use candle::{Result, Tensor, D};
+use candle_nn::{
+    batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
+};
+
+const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
+
+#[derive(Clone)]
+pub struct Config {
+    a: f32,
+    b: f32,
+    groups: usize,
+    stages: [usize; 4],
+}
+
+impl Config {
+    pub fn a0() -> Self {
+        Self {
+            a: 0.75,
+            b: 2.5,
+            groups: 1,
+            stages: [2, 4, 14, 1],
+        }
+    }
+
+    pub fn a1() -> Self {
+        Self {
+            a: 1.0,
+            b: 2.5,
+            groups: 1,
+            stages: [2, 4, 14, 1],
+        }
+    }
+
+    pub fn a2() -> Self {
+        Self {
+            a: 1.5,
+            b: 2.75,
+            groups: 1,
+            stages: [2, 4, 14, 1],
+        }
+    }
+
+    pub fn b0() -> Self {
+        Self {
+            a: 1.0,
+            b: 2.5,
+            groups: 1,
+            stages: [4, 6, 16, 1],
+        }
+    }
+
+    pub fn b1() -> Self {
+        Self {
+            a: 2.0,
+            b: 4.0,
+            groups: 1,
+            stages: [4, 6, 16, 1],
+        }
+    }
+
+    pub fn b2() -> Self {
+        Self {
+            a: 2.5,
+            b: 5.0,
+            groups: 1,
+            stages: [4, 6, 16, 1],
+        }
+    }
+
+    pub fn b3() -> Self {
+        Self {
+            a: 3.0,
+            b: 5.0,
+            groups: 1,
+            stages: [4, 6, 16, 1],
+        }
+    }
+
+    pub fn b1g4() -> Self {
+        Self {
+            a: 2.0,
+            b: 4.0,
+            groups: 4,
+            stages: [4, 6, 16, 1],
+        }
+    }
+
+    pub fn b2g4() -> Self {
+        Self {
+            a: 2.5,
+            b: 5.0,
+            groups: 4,
+            stages: [4, 6, 16, 1],
+        }
+    }
+
+    pub fn b3g4() -> Self {
+        Self {
+            a: 3.0,
+            b: 5.0,
+            groups: 4,
+            stages: [4, 6, 16, 1],
+        }
+    }
+}
+
+// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
+// based on the _fuse_bn_tensor method in timm
+// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
+fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
+    let (gamma, beta) = bn.weight_and_bias().unwrap();
+    let mu = bn.running_mean();
+    let sigma = (bn.running_var() + bn.eps())?.sqrt();
+    let gps = (gamma / sigma)?;
+    let bias = (beta - mu * &gps)?;
+    let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
+
+    Ok((weights, bias))
+}
+
+// A RepVGG layer has a different training time and inference time architecture.
+// The latter is a simple and efficient equivalent transformation of the former
+// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions
+// along with identity branches and batchnorm layers are fused into a single 3x3 convolution.
+fn repvgg_layer(
+    has_identity: bool,
+    dim: usize,
+    stride: usize,
+    in_channels: usize,
+    out_channels: usize,
+    groups: usize,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        stride,
+        groups,
+        padding: 1,
+        ..Default::default()
+    };
+
+    // read and reparameterize the 1x1 conv and bn into w1 and b1
+    // based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
+
+    let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
+    let conv1x1 = conv2d_no_bias(
+        in_channels,
+        out_channels,
+        1,
+        conv2d_cfg,
+        vb.pp("conv_1x1.conv"),
+    )?;
+
+    let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
+
+    // resize to 3x3
+    w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
+    w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
+
+    // read and reparameterize the 3x3 conv and bn into w3 and b3
+    let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
+    let conv3x3 = conv2d_no_bias(
+        in_channels,
+        out_channels,
+        3,
+        conv2d_cfg,
+        vb.pp("conv_kxk.conv"),
+    )?;
+
+    let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
+
+    let mut w = (w1 + w3)?;
+    let mut b = (b1 + b3)?;
+
+    // read and reparameterize the identity bn into wi and bi
+    if has_identity {
+        let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
+
+        // create a 3x3 convolution equivalent to the identity branch
+        let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
+
+        // https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
+        let in_dim = in_channels / groups;
+        for i in 0..in_channels {
+            weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
+        }
+
+        let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
+        let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
+
+        w = (w + wi)?;
+        b = (b + bi)?;
+    }
+
+    // create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches
+    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
+
+    Ok(Func::new(move |xs| {
+        let xs = xs.apply(&reparam_conv)?.relu()?;
+        Ok(xs)
+    }))
+}
+
+// Get the number of output channels per stage taking into account the multipliers
+fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
+    let channels = CHANNELS_PER_STAGE[stage] as f32;
+
+    match stage {
+        0 => std::cmp::min(64, (channels * a) as usize),
+        4 => (channels * b) as usize,
+        _ => (channels * a) as usize,
+    }
+}
+
+// Each stage is made of layers. The first layer always downsamples with stride 2.
+// All but the first layer have a residual connection.
+// The G4 variants have a groupwise convolution instead of a dense one on odd layers
+// counted across stage boundaries, so we keep track of which layer we are in the
+// full model.
+fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    let nlayers = cfg.stages[idx - 1];
+    let mut layers = Vec::with_capacity(nlayers);
+    let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
+    let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
+    let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
+
+    for layer_idx in 0..nlayers {
+        let (has_identity, stride, in_channels) = if layer_idx == 0 {
+            (false, 2, out_channels_prev)
+        } else {
+            (true, 1, out_channels)
+        };
+
+        let groups = if (prev_layers + layer_idx) % 2 == 1 {
+            cfg.groups
+        } else {
+            1
+        };
+
+        layers.push(repvgg_layer(
+            has_identity,
+            out_channels,
+            stride,
+            in_channels,
+            out_channels,
+            groups,
+            vb.pp(layer_idx),
+        )?)
+    }
+
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.clone();
+        for layer in layers.iter() {
+            xs = xs.apply(layer)?
+        }
+        Ok(xs)
+    }))
+}
+
+// Build a RepVGG model for a given configuration.
+fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
+    let cls = match nclasses {
+        None => None,
+        Some(nclasses) => {
+            let outputs = output_channels_per_stage(config.a, config.b, 4);
+            let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
+            Some(linear)
+        }
+    };
+
+    let stem_dim = output_channels_per_stage(config.a, config.b, 0);
+    let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
+    let vb = vb.pp("stages");
+    let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
+    let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
+    let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
+    let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
+
+    Ok(Func::new(move |xs| {
+        let xs = xs
+            .apply(&stem)?
+            .apply(&stage1)?
+            .apply(&stage2)?
+            .apply(&stage3)?
+            .apply(&stage4)?
+            .mean(D::Minus1)?
+            .mean(D::Minus1)?;
+        match &cls {
+            None => Ok(xs),
+            Some(cls) => xs.apply(cls),
+        }
+    }))
+}
+
+pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
+    repvgg_model(cfg, Some(nclasses), vb)
+}
+
+pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
+    repvgg_model(cfg, None, vb)
+}
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index 25454ba6..ea2a59b9 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,12 +1,7 @@
 use super::Config;
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
+use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 
 fn conv1d(
     in_channels: usize,
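A note on the reparameterization performed by `fuse_conv_bn` above: folding an inference-mode batchnorm into a bias-free convolution is a closed-form identity. With per-channel scale \gamma, shift \beta, running mean \mu, running variance \sigma^2 and kernel W, the composition is again a single convolution:

    \mathrm{BN}(W * x)
      = \gamma \cdot \frac{W * x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta
      = W' * x + b',
    \qquad
    W' = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\, W,
    \qquad
    b' = \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \varepsilon}}

which is exactly the `gps` scale, the `bias`, and the `broadcast_mul` computed in the code. The 1x1 branch is zero-padded to 3x3, and the identity branch is expressed as a 3x3 kernel whose center tap is 1.0 on the matching input channel, so all three branches collapse into one 3x3 convolution plus bias.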
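A minimal numerical check of that identity, as a standalone sketch (illustrative only, not part of this commit; all names are local to the example): run a bias-free conv followed by a manually applied batchnorm, then the fused conv built with the same arithmetic as `fuse_conv_bn`, and compare the outputs.

use candle::{Device, Result, Tensor};
use candle_nn::{Conv2d, Conv2dConfig, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (cin, cout) = (3usize, 4usize);
    let eps = 1e-5f64;
    let cfg = Conv2dConfig {
        padding: 1,
        ..Default::default()
    };

    // Random 3x3 kernel and per-channel batchnorm parameters/statistics.
    let w = Tensor::randn(0f32, 1f32, (cout, cin, 3, 3), &dev)?;
    let gamma = Tensor::randn(0f32, 1f32, cout, &dev)?;
    let beta = Tensor::randn(0f32, 1f32, cout, &dev)?;
    let mu = Tensor::randn(0f32, 1f32, cout, &dev)?;
    let var = Tensor::rand(0.5f32, 1.5f32, cout, &dev)?;
    let sigma = (&var + eps)?.sqrt()?;

    // Reference path: conv without bias, then batchnorm applied explicitly.
    let x = Tensor::randn(0f32, 1f32, (1, cin, 8, 8), &dev)?;
    let conv = Conv2d::new(w.clone(), None, cfg);
    let bcast = (1, cout, 1, 1);
    let y_ref = conv
        .forward(&x)?
        .broadcast_sub(&mu.reshape(bcast)?)?
        .broadcast_div(&sigma.reshape(bcast)?)?
        .broadcast_mul(&gamma.reshape(bcast)?)?
        .broadcast_add(&beta.reshape(bcast)?)?;

    // Fused path: same arithmetic as fuse_conv_bn in the diff.
    let gps = (gamma / &sigma)?;
    let b = (beta - (&mu * &gps)?)?;
    let w = w.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
    let y_fused = x.apply(&Conv2d::new(w, Some(b), cfg))?;

    // The two outputs agree up to floating point error.
    let max_diff = (y_ref - y_fused)?.abs()?.flatten_all()?.max(0)?;
    println!("max abs difference: {max_diff}");
    Ok(())
}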
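And a usage sketch for the new public API (also illustrative: the weight file name is an assumption, and the checkpoint must use the timm-style `stem` / `stages` / `head.fc` tensor names that the VarBuilder paths above expect):

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::repvgg;

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    // Hypothetical safetensors export of a RepVGG-A0 checkpoint.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["repvgg_a0.safetensors"], DType::F32, &device)?
    };
    // 1000 classes, matching an ImageNet head.
    let model = repvgg::repvgg(&repvgg::Config::a0(), 1000, vb)?;

    // Dummy input standing in for a preprocessed 224x224 RGB image.
    let image = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    let logits = image.apply(&model)?;
    assert_eq!(logits.dims(), &[1, 1000]);
    Ok(())
}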