Diffstat (limited to 'candle-transformers/src')
-rw-r--r--   candle-transformers/src/models/bert.rs             7
-rw-r--r--   candle-transformers/src/models/bigcode.rs          7
-rw-r--r--   candle-transformers/src/models/falcon.rs           7
-rw-r--r--   candle-transformers/src/models/llama.rs            9
-rw-r--r--   candle-transformers/src/models/mod.rs              1
-rw-r--r--   candle-transformers/src/models/repvgg.rs         306
-rw-r--r--   candle-transformers/src/models/whisper/model.rs    7
7 files changed, 313 insertions(+), 31 deletions(-)
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 51c524f5..810f2803 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -1,6 +1,6 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
@@ -112,11 +112,6 @@ impl Config {
}
}
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
struct Dropout {
#[allow(dead_code)]
pr: f64,
diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs
index c4a2d1db..e69f08c8 100644
--- a/candle-transformers/src/models/bigcode.rs
+++ b/candle-transformers/src/models/bigcode.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
@@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
Ok(Linear::new(weight, bias))
}
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 6ede136a..ef5a92fc 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;
@@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new(weight, bias, eps))
}
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
#[derive(Debug)]
pub struct Config {
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7e8c8920..f003866a 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,6 +1,6 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -136,11 +136,6 @@ impl Cache {
}
}
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, cfg.hidden_size))
-}
-
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
@@ -409,7 +404,7 @@ impl Llama {
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
- let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+ let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 94a3bd5b..a60b5a06 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -26,6 +26,7 @@ pub mod quantized_mixformer;
pub mod quantized_mpt;
pub mod quantized_stable_lm;
pub mod quantized_t5;
+pub mod repvgg;
pub mod resnet;
pub mod segment_anything;
pub mod stable_diffusion;
diff --git a/candle-transformers/src/models/repvgg.rs b/candle-transformers/src/models/repvgg.rs
new file mode 100644
index 00000000..34016e5b
--- /dev/null
+++ b/candle-transformers/src/models/repvgg.rs
@@ -0,0 +1,306 @@
+//! RepVGG inference implementation
+//!
+//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021
+//! https://arxiv.org/abs/2101.03697
+
+use candle::{Result, Tensor, D};
+use candle_nn::{
+ batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
+};
+
+const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
+
+#[derive(Clone)]
+pub struct Config {
+ a: f32,
+ b: f32,
+ groups: usize,
+ stages: [usize; 4],
+}
+
+impl Config {
+ pub fn a0() -> Self {
+ Self {
+ a: 0.75,
+ b: 2.5,
+ groups: 1,
+ stages: [2, 4, 14, 1],
+ }
+ }
+
+ pub fn a1() -> Self {
+ Self {
+ a: 1.0,
+ b: 2.5,
+ groups: 1,
+ stages: [2, 4, 14, 1],
+ }
+ }
+
+ pub fn a2() -> Self {
+ Self {
+ a: 1.5,
+ b: 2.75,
+ groups: 1,
+ stages: [2, 4, 14, 1],
+ }
+ }
+
+ pub fn b0() -> Self {
+ Self {
+ a: 1.0,
+ b: 2.5,
+ groups: 1,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b1() -> Self {
+ Self {
+ a: 2.0,
+ b: 4.0,
+ groups: 1,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b2() -> Self {
+ Self {
+ a: 2.5,
+ b: 5.0,
+ groups: 1,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b3() -> Self {
+ Self {
+ a: 3.0,
+ b: 5.0,
+ groups: 1,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b1g4() -> Self {
+ Self {
+ a: 2.0,
+ b: 4.0,
+ groups: 4,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b2g4() -> Self {
+ Self {
+ a: 2.5,
+ b: 5.0,
+ groups: 4,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b3g4() -> Self {
+ Self {
+ a: 3.0,
+ b: 5.0,
+ groups: 4,
+ stages: [4, 6, 16, 1],
+ }
+ }
+}
+
+// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
+// based on the _fuse_bn_tensor method in timm
+// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
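+// For a conv weight w and a batchnorm with scale gamma, shift beta, running mean mu
+// and running variance var, the fused parameters are:
+//   w' = w * gamma / sqrt(var + eps)
+//   b' = beta - mu * gamma / sqrt(var + eps)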
+fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
+ let (gamma, beta) = bn.weight_and_bias().unwrap();
+ let mu = bn.running_mean();
+ let sigma = (bn.running_var() + bn.eps())?.sqrt();
+ let gps = (gamma / sigma)?;
+ let bias = (beta - mu * &gps)?;
+ let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
+
+ Ok((weights, bias))
+}
+
+// A RepVGG layer has different architectures at training time and at inference time.
+// The latter is a simple and efficient equivalent transformation of the former
+// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions
+// along with identity branches and batchnorm layers are fused into a single 3x3 convolution.
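+// Concretely, the fused kernel built below is W = W_3x3' + pad(W_1x1') + W_id' and the
+// fused bias is b = b_3x3' + b_1x1' + b_id', where each primed term is a conv/bn pair
+// fused with fuse_conv_bn, the 1x1 kernel is zero-padded to 3x3, and the identity terms
+// are only added when the layer has an identity branch.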
+fn repvgg_layer(
+ has_identity: bool,
+ dim: usize,
+ stride: usize,
+ in_channels: usize,
+ out_channels: usize,
+ groups: usize,
+ vb: VarBuilder,
+) -> Result<Func<'static>> {
+ let conv2d_cfg = Conv2dConfig {
+ stride,
+ groups,
+ padding: 1,
+ ..Default::default()
+ };
+
+ // read and reparameterize the 1x1 conv and bn into w1 and b1
+ // based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
+
+ let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
+ let conv1x1 = conv2d_no_bias(
+ in_channels,
+ out_channels,
+ 1,
+ conv2d_cfg,
+ vb.pp("conv_1x1.conv"),
+ )?;
+
+ let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
+
+ // resize to 3x3
+ w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
+ w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
+
+ // read and reparameterize the 3x3 conv and bn into w3 and b3
+ let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
+ let conv3x3 = conv2d_no_bias(
+ in_channels,
+ out_channels,
+ 3,
+ conv2d_cfg,
+ vb.pp("conv_kxk.conv"),
+ )?;
+
+ let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
+
+ let mut w = (w1 + w3)?;
+ let mut b = (b1 + b3)?;
+
+ // read and reparameterize the identity bn into wi and bi
+ if has_identity {
+ let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
+
+ // create a 3x3 convolution equivalent to the identity branch
+ let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
+
+ // https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
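+ // the identity maps each channel to itself: a single 1.0 at the center (flat offset 4)
+ // of the 3x3 kernel acting on that channel's own position within its group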
+ let in_dim = in_channels / groups;
+ for i in 0..in_channels {
+ weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
+ }
+
+ let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
+ let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
+
+ w = (w + wi)?;
+ b = (b + bi)?;
+ }
+
+ // create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches
+ let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
+
+ Ok(Func::new(move |xs| {
+ let xs = xs.apply(&reparam_conv)?.relu()?;
+ Ok(xs)
+ }))
+}
+
+// Get the number of output channels per stage taking into account the multipliers
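+// For example, the a0 config (a = 0.75, b = 2.5) gives 48, 48, 96, 192 and 1280 channels
+// for stages 0 through 4.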
+fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
+ let channels = CHANNELS_PER_STAGE[stage] as f32;
+
+ match stage {
+ 0 => std::cmp::min(64, (channels * a) as usize),
+ 4 => (channels * b) as usize,
+ _ => (channels * a) as usize,
+ }
+}
+
+// Each stage is made of layers. The first layer always downsamples with stride 2.
+// All but the first layer have a residual connection.
+// The G4 variants use a grouped convolution instead of a dense one on odd layers
+// counted across stage boundaries, so we keep track of the layer index within the
+// full model.
+fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
+ let nlayers = cfg.stages[idx - 1];
+ let mut layers = Vec::with_capacity(nlayers);
+ let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
+ let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
+ let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
+
+ for layer_idx in 0..nlayers {
+ let (has_identity, stride, in_channels) = if layer_idx == 0 {
+ (false, 2, out_channels_prev)
+ } else {
+ (true, 1, out_channels)
+ };
+
+ let groups = if (prev_layers + layer_idx) % 2 == 1 {
+ cfg.groups
+ } else {
+ 1
+ };
+
+ layers.push(repvgg_layer(
+ has_identity,
+ out_channels,
+ stride,
+ in_channels,
+ out_channels,
+ groups,
+ vb.pp(layer_idx),
+ )?)
+ }
+
+ Ok(Func::new(move |xs| {
+ let mut xs = xs.clone();
+ for layer in layers.iter() {
+ xs = xs.apply(layer)?
+ }
+ Ok(xs)
+ }))
+}
+
+// Build a RepVGG model for a given configuration.
+fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
+ let cls = match nclasses {
+ None => None,
+ Some(nclasses) => {
+ let outputs = output_channels_per_stage(config.a, config.b, 4);
+ let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
+ Some(linear)
+ }
+ };
+
+ let stem_dim = output_channels_per_stage(config.a, config.b, 0);
+ let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
+ let vb = vb.pp("stages");
+ let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
+ let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
+ let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
+ let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
+
+ Ok(Func::new(move |xs| {
+ let xs = xs
+ .apply(&stem)?
+ .apply(&stage1)?
+ .apply(&stage2)?
+ .apply(&stage3)?
+ .apply(&stage4)?
+ .mean(D::Minus1)?
+ .mean(D::Minus1)?;
+ match &cls {
+ None => Ok(xs),
+ Some(cls) => xs.apply(cls),
+ }
+ }))
+}
+
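+// Usage sketch (assuming `vb` is a VarBuilder over pretrained RepVGG weights and
+// `image` is a normalized (3, H, W) image tensor):
+//   let model = repvgg(&Config::a0(), 1000, vb)?;
+//   let logits = image.unsqueeze(0)?.apply(&model)?;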
+pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
+ repvgg_model(cfg, Some(nclasses), vb)
+}
+
+pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
+ repvgg_model(cfg, None, vb)
+}
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index 25454ba6..ea2a59b9 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,12 +1,7 @@
use super::Config;
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
+use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
fn conv1d(
in_channels: usize,