Diffstat (limited to 'candle-transformers/src/models/efficientnet.rs')
-rw-r--r--  candle-transformers/src/models/efficientnet.rs | 331
1 file changed, 331 insertions, 0 deletions
diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs
new file mode 100644
index 00000000..ab51c76d
--- /dev/null
+++ b/candle-transformers/src/models/efficientnet.rs
@@ -0,0 +1,331 @@
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+use nn::{Module, VarBuilder};
+
+// Based on the Python version from torchvision.
+// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
+#[derive(Debug, Clone, Copy)]
+pub struct MBConvConfig {
+    expand_ratio: f64,
+    kernel: usize,
+    stride: usize,
+    input_channels: usize,
+    out_channels: usize,
+    num_layers: usize,
+}
+
+fn make_divisible(v: f64, divisor: usize) -> usize {
+    let min_value = divisor;
+    let new_v = usize::max(
+        min_value,
+        (v + divisor as f64 * 0.5) as usize / divisor * divisor,
+    );
+    if (new_v as f64) < 0.9 * v {
+        new_v + divisor
+    } else {
+        new_v
+    }
+}
+
+fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
+    let bneck_conf = |e, k, s, i, o, n| {
+        let input_channels = make_divisible(i as f64 * width_mult, 8);
+        let out_channels = make_divisible(o as f64 * width_mult, 8);
+        let num_layers = (n as f64 * depth_mult).ceil() as usize;
+        MBConvConfig {
+            expand_ratio: e,
+            kernel: k,
+            stride: s,
+            input_channels,
+            out_channels,
+            num_layers,
+        }
+    };
+    vec![
+        bneck_conf(1., 3, 1, 32, 16, 1),
+        bneck_conf(6., 3, 2, 16, 24, 2),
+        bneck_conf(6., 5, 2, 24, 40, 2),
+        bneck_conf(6., 3, 2, 40, 80, 3),
+        bneck_conf(6., 5, 1, 80, 112, 3),
+        bneck_conf(6., 5, 2, 112, 192, 4),
+        bneck_conf(6., 3, 1, 192, 320, 1),
+    ]
+}
+
+impl MBConvConfig {
+    pub fn b0() -> Vec<Self> {
+        bneck_confs(1.0, 1.0)
+    }
+    pub fn b1() -> Vec<Self> {
+        bneck_confs(1.0, 1.1)
+    }
+    pub fn b2() -> Vec<Self> {
+        bneck_confs(1.1, 1.2)
+    }
+    pub fn b3() -> Vec<Self> {
+        bneck_confs(1.2, 1.4)
+    }
+    pub fn b4() -> Vec<Self> {
+        bneck_confs(1.4, 1.8)
+    }
+    pub fn b5() -> Vec<Self> {
+        bneck_confs(1.6, 2.2)
+    }
+    pub fn b6() -> Vec<Self> {
+        bneck_confs(1.8, 2.6)
+    }
+    pub fn b7() -> Vec<Self> {
+        bneck_confs(2.0, 3.1)
+    }
+}
+
+/// Conv2D with same padding.
+#[derive(Debug)]
+struct Conv2DSame {
+    conv2d: nn::Conv2d,
+    s: usize,
+    k: usize,
+}
+
+impl Conv2DSame {
+    fn new(
+        vb: VarBuilder,
+        i: usize,
+        o: usize,
+        k: usize,
+        stride: usize,
+        groups: usize,
+        bias: bool,
+    ) -> Result<Self> {
+        let conv_config = nn::Conv2dConfig {
+            stride,
+            groups,
+            ..Default::default()
+        };
+        let conv2d = if bias {
+            nn::conv2d(i, o, k, conv_config, vb)?
+        } else {
+            nn::conv2d_no_bias(i, o, k, conv_config, vb)?
+        };
+        Ok(Self {
+            conv2d,
+            s: stride,
+            k,
+        })
+    }
+}
+
+impl Module for Conv2DSame {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let s = self.s;
+        let k = self.k;
+        let (_, _, ih, iw) = xs.dims4()?;
+        let oh = (ih + s - 1) / s;
+        let ow = (iw + s - 1) / s;
+        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
+        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
+        if pad_h > 0 || pad_w > 0 {
+            let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
+            let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
+            self.conv2d.forward(&xs)
+        } else {
+            self.conv2d.forward(xs)
+        }
+    }
+}
+
+#[derive(Debug)]
+struct ConvNormActivation {
+    conv2d: Conv2DSame,
+    bn2d: nn::BatchNorm,
+    activation: bool,
+}
+
+impl ConvNormActivation {
+    fn new(
+        vb: VarBuilder,
+        i: usize,
+        o: usize,
+        k: usize,
+        stride: usize,
+        groups: usize,
+    ) -> Result<Self> {
+        let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
+        let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
+        Ok(Self {
+            conv2d,
+            bn2d,
+            activation: true,
+        })
+    }
+
+    fn no_activation(self) -> Self {
+        Self {
+            activation: false,
+            ..self
+        }
+    }
+}
+
+impl Module for ConvNormActivation {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.conv2d.forward(xs)?;
+        let xs = self.bn2d.forward(&xs)?;
+        if self.activation {
+            swish(&xs)
+        } else {
+            Ok(xs)
+        }
+    }
+}
+
+#[derive(Debug)]
+struct SqueezeExcitation {
+    fc1: Conv2DSame,
+    fc2: Conv2DSame,
+}
+
+impl SqueezeExcitation {
+    fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
+        let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
+        let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
+        Ok(Self { fc1, fc2 })
+    }
+}
+
+impl Module for SqueezeExcitation {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let residual = xs;
+        // equivalent to adaptive_avg_pool2d([1, 1])
+        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
+        let xs = self.fc1.forward(&xs)?;
+        let xs = swish(&xs)?;
+        let xs = self.fc2.forward(&xs)?;
+        let xs = nn::ops::sigmoid(&xs)?;
+        residual.broadcast_mul(&xs)
+    }
+}
+
+#[derive(Debug)]
+struct MBConv {
+    expand_cna: Option<ConvNormActivation>,
+    depthwise_cna: ConvNormActivation,
+    squeeze_excitation: SqueezeExcitation,
+    project_cna: ConvNormActivation,
+    config: MBConvConfig,
+}
+
+impl MBConv {
+    fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
+        let vb = vb.pp("block");
+        let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
+        let expand_cna = if exp != c.input_channels {
+            Some(ConvNormActivation::new(
+                vb.pp("0"),
+                c.input_channels,
+                exp,
+                1,
+                1,
+                1,
+            )?)
+        } else {
+            None
+        };
+        let start_index = if expand_cna.is_some() { 1 } else { 0 };
+        let depthwise_cna =
+            ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
+        let squeeze_channels = usize::max(1, c.input_channels / 4);
+        let squeeze_excitation =
+            SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
+        let project_cna =
+            ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
+                .no_activation();
+        Ok(Self {
+            expand_cna,
+            depthwise_cna,
+            squeeze_excitation,
+            project_cna,
+            config: c,
+        })
+    }
+}
+
+impl Module for MBConv {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let use_res_connect =
+            self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
+        let ys = match &self.expand_cna {
+            Some(expand_cna) => expand_cna.forward(xs)?,
+            None => xs.clone(),
+        };
+        let ys = self.depthwise_cna.forward(&ys)?;
+        let ys = self.squeeze_excitation.forward(&ys)?;
+        let ys = self.project_cna.forward(&ys)?;
+        if use_res_connect {
+            ys + xs
+        } else {
+            Ok(ys)
+        }
+    }
+}
+
+fn swish(s: &Tensor) -> Result<Tensor> {
+    s * nn::ops::sigmoid(s)?
+}
+
+#[derive(Debug)]
+pub struct EfficientNet {
+    init_cna: ConvNormActivation,
+    blocks: Vec<MBConv>,
+    final_cna: ConvNormActivation,
+    classifier: nn::Linear,
+}
+
+impl EfficientNet {
+    pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
+        let f_p = p.pp("features");
+        let first_in_c = configs[0].input_channels;
+        let last_out_c = configs.last().unwrap().out_channels;
+        let final_out_c = 4 * last_out_c;
+        let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
+        let nconfigs = configs.len();
+        let mut blocks = vec![];
+        for (index, cnf) in configs.into_iter().enumerate() {
+            let f_p = f_p.pp(index + 1);
+            for r_index in 0..cnf.num_layers {
+                let cnf = if r_index == 0 {
+                    cnf
+                } else {
+                    MBConvConfig {
+                        input_channels: cnf.out_channels,
+                        stride: 1,
+                        ..cnf
+                    }
+                };
+                blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
+            }
+        }
+        let final_cna =
+            ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
+        let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
+        Ok(Self {
+            init_cna,
+            blocks,
+            final_cna,
+            classifier,
+        })
+    }
+}
+
+impl Module for EfficientNet {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = self.init_cna.forward(xs)?;
+        for block in self.blocks.iter() {
+            xs = block.forward(&xs)?
+        }
+        let xs = self.final_cna.forward(&xs)?;
+        // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
+        let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
+        self.classifier.forward(&xs)
+    }
+}
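
A note on the rounding helper: make_divisible snaps a scaled channel count to the nearest multiple of the divisor (8 throughout), but never lets the result fall below 90% of the requested value; in that case it bumps the count up by one extra divisor. A hypothetical check of both branches (the helper is private to the module, so this assumes the body above copied verbatim):

    assert_eq!(make_divisible(32.0 * 1.1, 8), 32); // b2 stem: 35.2 snaps to 32, still >= 0.9 * 35.2
    assert_eq!(make_divisible(9.0, 8), 16); // nearest multiple is 8, but 8 < 0.9 * 9.0, so bump to 16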
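
Conv2DSame reproduces TensorFlow-style "SAME" padding, which plain PyTorch/candle convolutions do not provide: the output spatial size is ceil(input / stride), and an odd total padding puts the extra pixel on the bottom/right. A sketch of that arithmetic for the 3x3, stride-2 stem convolution on a 224-pixel side:

    let (ih, s, k) = (224usize, 2usize, 3usize);
    let oh = (ih + s - 1) / s; // ceil(224 / 2) = 112
    let pad_h = (oh - 1) * s + k - ih; // 111 * 2 + 3 - 224 = 1
    assert_eq!((pad_h / 2, pad_h - pad_h / 2), (0, 1)); // 0 rows on top, 1 on the bottom

With the kernel/stride combinations these configs actually use (1x1 at stride 1, 3x3 and 5x5 at strides 1 or 2), the subtraction never underflows, so the usize::max(..., 0) guard is effectively a no-op.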
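
For context, a minimal end-to-end sketch of how the model might be instantiated and run. This is not part of the commit: it assumes candle_nn's VarMap for freshly initialized (not pretrained) weights and the 1000-class head of the ImageNet checkpoints, so it only exercises shapes:

    use candle::{DType, Device, Tensor};
    use candle_nn::{Module, VarBuilder, VarMap};
    use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};

    fn main() -> candle::Result<()> {
        let device = Device::Cpu;
        // Fresh variables; real use would load a checkpoint into the VarBuilder instead.
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let model = EfficientNet::new(vb, MBConvConfig::b0(), 1000)?;
        // One 224x224 RGB image in NCHW layout.
        let image = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
        let logits = model.forward(&image)?;
        assert_eq!(logits.dims(), &[1, 1000]);
        Ok(())
    }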