diff options
Diffstat (limited to 'candle-transformers/src/models/mimi/seanet.rs')
-rw-r--r-- | candle-transformers/src/models/mimi/seanet.rs | 465 |
1 files changed, 465 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mimi/seanet.rs b/candle-transformers/src/models/mimi/seanet.rs new file mode 100644 index 00000000..aa5c7d21 --- /dev/null +++ b/candle-transformers/src/models/mimi/seanet.rs @@ -0,0 +1,465 @@ +// Copyright (c) Kyutai, all rights reserved. +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor}; +use candle_nn::VarBuilder; + +use super::conv::{StreamableConv1d, StreamableConvTranspose1d}; + +#[derive(Debug, Clone)] +pub struct Config { + pub dimension: usize, + pub channels: usize, + pub causal: bool, + pub n_filters: usize, + pub n_residual_layers: usize, + pub ratios: Vec<usize>, + pub activation: candle_nn::Activation, + pub norm: super::conv::Norm, + pub kernel_size: usize, + pub residual_kernel_size: usize, + pub last_kernel_size: usize, + pub dilation_base: usize, + pub pad_mode: super::conv::PadMode, + pub true_skip: bool, + pub compress: usize, + pub lstm: usize, + pub disable_norm_outer_blocks: usize, + pub final_activation: Option<candle_nn::Activation>, +} + +#[derive(Debug, Clone)] +pub struct SeaNetResnetBlock { + block: Vec<StreamableConv1d>, + shortcut: Option<StreamableConv1d>, + activation: candle_nn::Activation, + skip_op: candle::StreamingBinOp, + span: tracing::Span, +} + +impl SeaNetResnetBlock { + #[allow(clippy::too_many_arguments)] + pub fn new( + dim: usize, + k_sizes_and_dilations: &[(usize, usize)], + activation: candle_nn::Activation, + norm: Option<super::conv::Norm>, + causal: bool, + pad_mode: super::conv::PadMode, + compress: usize, + true_skip: bool, + vb: VarBuilder, + ) -> Result<Self> { + let mut block = Vec::with_capacity(k_sizes_and_dilations.len()); + let hidden = dim / compress; + let vb_b = vb.pp("block"); + for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() { + let in_c = if i == 0 { dim } else { hidden }; + let out_c = if i == k_sizes_and_dilations.len() - 1 { + dim + } else { + hidden + }; + let c = StreamableConv1d::new( + in_c, + out_c, + /* k_size */ *k_size, + /* stride */ 1, + /* dilation */ *dilation, + /* groups */ 1, + /* bias */ true, + /* causal */ causal, + /* norm */ norm, + /* pad_mode */ pad_mode, + vb_b.pp(2 * i + 1), + )?; + block.push(c) + } + let shortcut = if true_skip { + None + } else { + let c = StreamableConv1d::new( + dim, + dim, + /* k_size */ 1, + /* stride */ 1, + /* dilation */ 1, + /* groups */ 1, + /* bias */ true, + /* causal */ causal, + /* norm */ norm, + /* pad_mode */ pad_mode, + vb.pp("shortcut"), + )?; + Some(c) + }; + Ok(Self { + block, + shortcut, + activation, + skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1), + span: tracing::span!(tracing::Level::TRACE, "sea-resnet"), + }) + } +} + +impl Module for SeaNetResnetBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut ys = xs.clone(); + for block in self.block.iter() { + ys = ys.apply(&self.activation)?.apply(block)?; + } + match self.shortcut.as_ref() { + None => ys + xs, + Some(shortcut) => ys + xs.apply(shortcut), + } + } +} + +impl StreamingModule for SeaNetResnetBlock { + fn reset_state(&mut self) { + for block in self.block.iter_mut() { + block.reset_state() + } + if let Some(shortcut) = self.shortcut.as_mut() { + shortcut.reset_state() + } + } + + fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> { + let _enter = self.span.enter(); + let mut ys = xs.clone(); + for block in self.block.iter_mut() { + ys = block.step(&ys.apply(&self.activation)?)?; + } + match self.shortcut.as_ref() { + None => self.skip_op.step(&ys, xs), + Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?), + } + } +} + +#[derive(Debug, Clone)] +struct EncoderLayer { + residuals: Vec<SeaNetResnetBlock>, + downsample: StreamableConv1d, +} + +#[derive(Debug, Clone)] +pub struct SeaNetEncoder { + init_conv1d: StreamableConv1d, + activation: candle_nn::Activation, + layers: Vec<EncoderLayer>, + final_conv1d: StreamableConv1d, + span: tracing::Span, +} + +impl SeaNetEncoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + if cfg.lstm > 0 { + candle::bail!("seanet lstm is not supported") + } + let n_blocks = 2 + cfg.ratios.len(); + let mut mult = 1usize; + let init_norm = if cfg.disable_norm_outer_blocks >= 1 { + None + } else { + Some(cfg.norm) + }; + let mut layer_idx = 0; + let vb = vb.pp("layers"); + let init_conv1d = StreamableConv1d::new( + cfg.channels, + mult * cfg.n_filters, + cfg.kernel_size, + /* stride */ 1, + /* dilation */ 1, + /* groups */ 1, + /* bias */ true, + /* causal */ cfg.causal, + /* norm */ init_norm, + /* pad_mode */ cfg.pad_mode, + vb.pp(layer_idx), + )?; + layer_idx += 1; + let mut layers = Vec::with_capacity(cfg.ratios.len()); + + for (i, &ratio) in cfg.ratios.iter().rev().enumerate() { + let norm = if cfg.disable_norm_outer_blocks >= i + 2 { + None + } else { + Some(cfg.norm) + }; + let mut residuals = Vec::with_capacity(cfg.n_residual_layers); + for j in 0..cfg.n_residual_layers { + let resnet_block = SeaNetResnetBlock::new( + mult * cfg.n_filters, + &[ + (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), + (1, 1), + ], + cfg.activation, + norm, + cfg.causal, + cfg.pad_mode, + cfg.compress, + cfg.true_skip, + vb.pp(layer_idx), + )?; + residuals.push(resnet_block); + layer_idx += 1; + } + let downsample = StreamableConv1d::new( + mult * cfg.n_filters, + mult * cfg.n_filters * 2, + /* k_size */ ratio * 2, + /* stride */ ratio, + /* dilation */ 1, + /* groups */ 1, + /* bias */ true, + /* causal */ true, + /* norm */ norm, + /* pad_mode */ cfg.pad_mode, + vb.pp(layer_idx + 1), + )?; + layer_idx += 2; + let layer = EncoderLayer { + downsample, + residuals, + }; + layers.push(layer); + mult *= 2 + } + + let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks { + None + } else { + Some(cfg.norm) + }; + let final_conv1d = StreamableConv1d::new( + mult * cfg.n_filters, + cfg.dimension, + cfg.last_kernel_size, + /* stride */ 1, + /* dilation */ 1, + /* groups */ 1, + /* bias */ true, + /* causal */ cfg.causal, + /* norm */ final_norm, + /* pad_mode */ cfg.pad_mode, + vb.pp(layer_idx + 1), + )?; + Ok(Self { + init_conv1d, + activation: cfg.activation, + layers, + final_conv1d, + span: tracing::span!(tracing::Level::TRACE, "sea-encoder"), + }) + } +} + +impl Module for SeaNetEncoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.apply(&self.init_conv1d)?; + for layer in self.layers.iter() { + for residual in layer.residuals.iter() { + xs = xs.apply(residual)? + } + xs = xs.apply(&self.activation)?.apply(&layer.downsample)?; + } + xs.apply(&self.activation)?.apply(&self.final_conv1d) + } +} + +impl StreamingModule for SeaNetEncoder { + fn reset_state(&mut self) { + self.init_conv1d.reset_state(); + self.layers.iter_mut().for_each(|v| { + v.residuals.iter_mut().for_each(|v| v.reset_state()); + v.downsample.reset_state() + }); + self.final_conv1d.reset_state(); + } + + fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> { + let _enter = self.span.enter(); + let mut xs = self.init_conv1d.step(xs)?; + for layer in self.layers.iter_mut() { + for residual in layer.residuals.iter_mut() { + xs = residual.step(&xs)?; + } + xs = layer.downsample.step(&xs.apply(&self.activation)?)?; + } + self.final_conv1d.step(&xs.apply(&self.activation)?) + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + upsample: StreamableConvTranspose1d, + residuals: Vec<SeaNetResnetBlock>, +} + +#[derive(Debug, Clone)] +pub struct SeaNetDecoder { + init_conv1d: StreamableConv1d, + activation: candle_nn::Activation, + layers: Vec<DecoderLayer>, + final_conv1d: StreamableConv1d, + final_activation: Option<candle_nn::Activation>, + span: tracing::Span, +} + +impl SeaNetDecoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + if cfg.lstm > 0 { + candle::bail!("seanet lstm is not supported") + } + let n_blocks = 2 + cfg.ratios.len(); + let mut mult = 1 << cfg.ratios.len(); + let init_norm = if cfg.disable_norm_outer_blocks == n_blocks { + None + } else { + Some(cfg.norm) + }; + let mut layer_idx = 0; + let vb = vb.pp("layers"); + let init_conv1d = StreamableConv1d::new( + cfg.dimension, + mult * cfg.n_filters, + cfg.kernel_size, + /* stride */ 1, + /* dilation */ 1, + /* groups */ 1, + /* bias */ true, + /* causal */ cfg.causal, + /* norm */ init_norm, + /* pad_mode */ cfg.pad_mode, + vb.pp(layer_idx), + )?; + layer_idx += 1; + let mut layers = Vec::with_capacity(cfg.ratios.len()); + for (i, &ratio) in cfg.ratios.iter().enumerate() { + let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks { + None + } else { + Some(cfg.norm) + }; + let upsample = StreamableConvTranspose1d::new( + mult * cfg.n_filters, + mult * cfg.n_filters / 2, + /* k_size */ ratio * 2, + /* stride */ ratio, + /* groups */ 1, + /* bias */ true, + /* causal */ true, + /* norm */ norm, + vb.pp(layer_idx + 1), + )?; + layer_idx += 2; + + let mut residuals = Vec::with_capacity(cfg.n_residual_layers); + for j in 0..cfg.n_residual_layers { + let resnet_block = SeaNetResnetBlock::new( + mult * cfg.n_filters / 2, + &[ + (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)), + (1, 1), + ], + cfg.activation, + norm, + cfg.causal, + cfg.pad_mode, + cfg.compress, + cfg.true_skip, + vb.pp(layer_idx), + )?; + residuals.push(resnet_block); + layer_idx += 1; + } + let layer = DecoderLayer { + upsample, + residuals, + }; + layers.push(layer); + mult /= 2 + } + let final_norm = if cfg.disable_norm_outer_blocks >= 1 { + None + } else { + Some(cfg.norm) + }; + let final_conv1d = StreamableConv1d::new( + cfg.n_filters, + cfg.channels, + cfg.last_kernel_size, + /* stride */ 1, + /* dilation */ 1, + /* groups */ 1, + /* bias */ true, + /* causal */ cfg.causal, + /* norm */ final_norm, + /* pad_mode */ cfg.pad_mode, + vb.pp(layer_idx + 1), + )?; + Ok(Self { + init_conv1d, + activation: cfg.activation, + layers, + final_conv1d, + final_activation: cfg.final_activation, + span: tracing::span!(tracing::Level::TRACE, "sea-decoder"), + }) + } +} + +impl Module for SeaNetDecoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.apply(&self.init_conv1d)?; + for layer in self.layers.iter() { + xs = xs.apply(&self.activation)?.apply(&layer.upsample)?; + for residual in layer.residuals.iter() { + xs = xs.apply(residual)? + } + } + let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?; + let xs = match self.final_activation.as_ref() { + None => xs, + Some(act) => xs.apply(act)?, + }; + Ok(xs) + } +} + +impl StreamingModule for SeaNetDecoder { + fn reset_state(&mut self) { + self.init_conv1d.reset_state(); + self.layers.iter_mut().for_each(|v| { + v.residuals.iter_mut().for_each(|v| v.reset_state()); + v.upsample.reset_state() + }); + self.final_conv1d.reset_state(); + } + + fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> { + let _enter = self.span.enter(); + let mut xs = self.init_conv1d.step(xs)?; + for layer in self.layers.iter_mut() { + xs = layer.upsample.step(&xs.apply(&self.activation)?)?; + for residual in layer.residuals.iter_mut() { + xs = residual.step(&xs)?; + } + } + let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?; + let xs = match self.final_activation.as_ref() { + None => xs, + Some(act) => xs.apply(act)?, + }; + Ok(xs) + } +} |