Diffstat (limited to 'candle-transformers/src/models/mimi/conv.rs')
-rw-r--r--  candle-transformers/src/models/mimi/conv.rs  670
1 file changed, 670 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mimi/conv.rs b/candle-transformers/src/models/mimi/conv.rs
new file mode 100644
index 00000000..87e9fb4c
--- /dev/null
+++ b/candle-transformers/src/models/mimi/conv.rs
@@ -0,0 +1,670 @@
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D};
+use candle_nn::{Conv1d, VarBuilder};
+
+#[allow(clippy::enum_variant_names)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum Norm {
+ WeightNorm,
+ SpectralNorm,
+ TimeGroupNorm,
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum PadMode {
+ Constant,
+ Reflect,
+ Replicate,
+}
+
+// Applies weight norm for inference by recomputing the weight tensor. This
+// does not apply to training.
+// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
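+// The reparametrized weight is recomputed as weight = weight_g * weight_v / ||weight_v||,
+// where the L2 norm of weight_v is taken over its (in_c, kernel_size) dims for each output
+// channel, matching the (out_c, 1, 1) shape of weight_g.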
+fn conv1d_weight_norm(
+ in_c: usize,
+ out_c: usize,
+ kernel_size: usize,
+ bias: bool,
+ config: candle_nn::Conv1dConfig,
+ vb: VarBuilder,
+) -> Result<Conv1d> {
+ let weight = if vb.contains_tensor("weight") {
+ vb.get((out_c, in_c, kernel_size), "weight")?
+ } else {
+ let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
+ let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
+ let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
+ weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
+ };
+ let bias = if bias {
+ Some(vb.get(out_c, "bias")?)
+ } else {
+ None
+ };
+ Ok(Conv1d::new(weight, bias, config))
+}
+
+#[derive(Debug, Clone)]
+pub struct NormConv1d {
+ conv: Conv1d,
+ norm: Option<candle_nn::GroupNorm>,
+ span: tracing::Span,
+}
+
+impl NormConv1d {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ in_c: usize,
+ out_c: usize,
+ k_size: usize,
+ causal: bool,
+ norm: Option<Norm>,
+ bias: bool,
+ cfg: candle_nn::Conv1dConfig,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let conv = match norm {
+ None | Some(Norm::TimeGroupNorm) => {
+ if bias {
+ candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
+ } else {
+ candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
+ }
+ }
+ Some(Norm::WeightNorm) => {
+ conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
+ }
+ Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
+ };
+ let norm = match norm {
+ None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
+ Some(Norm::TimeGroupNorm) => {
+ if causal {
+ candle::bail!("GroupNorm doesn't support causal evaluation.")
+ }
+ let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
+ Some(norm)
+ }
+ };
+ Ok(Self {
+ conv,
+ norm,
+ span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
+ })
+ }
+}
+
+impl Module for NormConv1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = xs.apply(&self.conv)?;
+ match self.norm.as_ref() {
+ None => Ok(xs),
+ Some(norm) => xs.apply(norm),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct NormConvTranspose1d {
+ ws: Tensor,
+ bs: Option<Tensor>,
+ k_size: usize,
+ stride: usize,
+ groups: usize,
+ norm: Option<candle_nn::GroupNorm>,
+ span: tracing::Span,
+}
+
+impl NormConvTranspose1d {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ in_c: usize,
+ out_c: usize,
+ k_size: usize,
+ causal: bool,
+ norm: Option<Norm>,
+ bias: bool,
+ stride: usize,
+ groups: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb = vb.pp("conv");
+ let bs = if bias {
+ Some(vb.get(out_c, "bias")?)
+ } else {
+ None
+ };
+ let ws = match norm {
+ None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
+ Some(Norm::WeightNorm) => {
+ if vb.contains_tensor("weight") {
+ vb.get((in_c, out_c, k_size), "weight")?
+ } else {
+ let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
+ let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
+ let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
+ weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
+ }
+ }
+ Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
+ };
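+ // For a depthwise transposed convolution (groups == in_c == out_c), expand the
+ // per-channel weights into a dense (in_c, out_c, k_size) weight that is zero off
+ // the channel diagonal and use a single group, presumably to avoid relying on
+ // grouped conv-transpose support.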
+ let (ws, groups) = if groups == out_c && in_c == out_c {
+ let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
+ let ws = ws
+ .repeat((1, out_c, 1))?
+ .mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
+ (ws, 1)
+ } else {
+ (ws, groups)
+ };
+ let norm = match norm {
+ None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
+ Some(Norm::TimeGroupNorm) => {
+ if causal {
+ candle::bail!("GroupNorm doesn't support causal evaluation.")
+ }
+ let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
+ Some(norm)
+ }
+ };
+ Ok(Self {
+ ws,
+ bs,
+ k_size,
+ stride,
+ groups,
+ norm,
+ span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
+ })
+ }
+}
+
+impl Module for NormConvTranspose1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ // conv-transpose1d used to be broken on metal after enough iterations, causing
+ // the following error:
+ // _status < MTLCommandBufferStatusCommitted >
+ // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
+ // This is now fixed in candle.
+ let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
+ let xs = match &self.bs {
+ None => xs,
+ Some(bias) => {
+ let b = bias.dims1()?;
+ let bias = bias.reshape((1, b, 1))?;
+ xs.broadcast_add(&bias)?
+ }
+ };
+ match self.norm.as_ref() {
+ None => Ok(xs),
+ Some(norm) => xs.apply(norm),
+ }
+ }
+}
+
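+// Extra right-padding needed so that the (already padded) input length corresponds to an
+// integer number of convolution frames, i.e. so that the last window is complete.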
+fn get_extra_padding_for_conv1d(
+ xs: &Tensor,
+ k_size: usize,
+ stride: usize,
+ padding_total: usize,
+) -> Result<usize> {
+ let len = xs.dim(D::Minus1)?;
+ let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
+ let ideal_len =
+ ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
+ Ok(ideal_len.saturating_sub(len))
+}
+
+fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
+ match mode {
+ PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
+ PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
+ PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
+ }
+}
+
+fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
+ let len = xs.dim(D::Minus1)?;
+ if len < unpad_l + unpad_r {
+ candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
+ }
+ xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
+}
+
+#[derive(Debug, Clone)]
+pub struct StreamableConv1d {
+ conv: NormConv1d,
+ causal: bool,
+ pad_mode: PadMode,
+ state_prev_xs: StreamTensor,
+ left_pad_applied: bool,
+ kernel_size: usize,
+ span: tracing::Span,
+}
+
+impl StreamableConv1d {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ in_c: usize,
+ out_c: usize,
+ k_size: usize,
+ stride: usize,
+ dilation: usize,
+ groups: usize,
+ bias: bool,
+ causal: bool,
+ norm: Option<Norm>,
+ pad_mode: PadMode,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let cfg = candle_nn::Conv1dConfig {
+ padding: 0,
+ stride,
+ dilation,
+ groups,
+ };
+ let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
+ if k_size < stride {
+ candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
+ }
+ Ok(Self {
+ conv,
+ causal,
+ pad_mode,
+ state_prev_xs: StreamTensor::empty(),
+ left_pad_applied: false,
+ kernel_size: k_size,
+ span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
+ })
+ }
+}
+
+impl Module for StreamableConv1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (_b, _c, _t) = xs.dims3()?;
+ let k_size = self.conv.conv.weight().dim(D::Minus1)?;
+ let conv_cfg = self.conv.conv.config();
+ // Effective kernel size with dilations.
+ let k_size = (k_size - 1) * conv_cfg.dilation + 1;
+ let padding_total = k_size - conv_cfg.stride;
+ let extra_padding =
+ get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
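+ // Causal convolutions put all of the padding on the left; otherwise it is split
+ // between the two sides, with the extra padding always added on the right.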
+ let xs = if self.causal {
+ pad1d(xs, padding_total, extra_padding, self.pad_mode)?
+ } else {
+ let padding_right = padding_total / 2;
+ let padding_left = padding_total - padding_right;
+ pad1d(
+ xs,
+ padding_left,
+ padding_right + extra_padding,
+ self.pad_mode,
+ )?
+ };
+ xs.apply(&self.conv)
+ }
+}
+
+impl StreamingModule for StreamableConv1d {
+ fn reset_state(&mut self) {
+ self.state_prev_xs.reset();
+ self.left_pad_applied = false;
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let _enter = self.span.enter();
+ let xs = match xs.as_option() {
+ None => return Ok(().into()),
+ Some(xs) => xs.clone(),
+ };
+ let xs = if self.left_pad_applied {
+ xs
+ } else {
+ self.left_pad_applied = true;
+ let k_size = self.conv.conv.weight().dim(D::Minus1)?;
+ let conv_cfg = self.conv.conv.config();
+ let k_size = (k_size - 1) * conv_cfg.dilation + 1;
+ let padding_total = k_size - conv_cfg.stride;
+ pad1d(&xs, padding_total, 0, self.pad_mode)?
+ };
+ let cfg = self.conv.conv.config();
+ let stride = cfg.stride;
+ let dilation = cfg.dilation;
+ let kernel = (self.kernel_size - 1) * dilation + 1;
+ let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
+ let seq_len = xs.seq_len(D::Minus1)?;
+ let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
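+ // num_frames is the number of complete convolution windows available in the
+ // buffered input; those samples are convolved below, and everything past
+ // num_frames * stride (still needed by future windows) is kept as state.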
+ if num_frames > 0 {
+ let offset = num_frames * stride;
+ self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
+ let in_l = (num_frames - 1) * stride + kernel;
+ let xs = xs.narrow(D::Minus1, 0, in_l)?;
+ // We apply the underlying conv directly rather than through forward so as
+ // not to apply any padding here.
+ xs.apply(&self.conv.conv)
+ } else {
+ self.state_prev_xs = xs;
+ Ok(StreamTensor::empty())
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct StreamableConvTranspose1d {
+ convtr: NormConvTranspose1d,
+ causal: bool,
+ state_prev_ys: StreamTensor,
+ kernel_size: usize,
+ span: tracing::Span,
+}
+
+impl StreamableConvTranspose1d {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ in_c: usize,
+ out_c: usize,
+ k_size: usize,
+ stride: usize,
+ groups: usize,
+ bias: bool,
+ causal: bool,
+ norm: Option<Norm>,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let convtr =
+ NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?;
+ Ok(Self {
+ convtr,
+ causal,
+ kernel_size: k_size,
+ state_prev_ys: StreamTensor::empty(),
+ span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
+ })
+ }
+}
+
+impl Module for StreamableConvTranspose1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let k_size = self.convtr.k_size;
+ let stride = self.convtr.stride;
+ let padding_total = k_size.saturating_sub(stride);
+ let xs = xs.apply(&self.convtr)?;
+ if self.causal {
+ // This corresponds to trim_right_ratio = 1.
+ unpad1d(&xs, 0, padding_total)
+ } else {
+ let padding_right = padding_total / 2;
+ let padding_left = padding_total - padding_right;
+ unpad1d(&xs, padding_left, padding_right)
+ }
+ }
+}
+
+impl StreamingModule for StreamableConvTranspose1d {
+ fn reset_state(&mut self) {
+ self.state_prev_ys.reset()
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let _enter = self.span.enter();
+ let xs = match xs.as_option() {
+ Some(xs) => xs,
+ None => return Ok(StreamTensor::empty()),
+ };
+ let stride = self.convtr.stride;
+ // We apply the underlying convtr directly rather than through forward so as
+ // not to trim the output here; the invalid trailing steps are handled by the
+ // overlap-add logic below.
+ let ys = self.convtr.forward(xs)?;
+ let ot = ys.dim(D::Minus1)?;
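+ // Overlap-add: the last kernel_size - stride output steps of the previous chunk
+ // overlap with the head of this chunk's output, so the buffered state_prev_ys is
+ // added onto it and the new invalid tail is buffered in turn below.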
+ let ys = match self.state_prev_ys.as_option() {
+ None => ys,
+ Some(prev_ys) => {
+ let pt = prev_ys.dim(D::Minus1)?;
+ // Remove the bias as it will be applied multiple times.
+ let prev_ys = match &self.convtr.bs {
+ None => prev_ys.clone(),
+ Some(bias) => {
+ let bias = bias.reshape((1, (), 1))?;
+ prev_ys.broadcast_sub(&bias)?
+ }
+ };
+ let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
+ let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
+ Tensor::cat(&[ys1, ys2], D::Minus1)?
+ }
+ };
+ let invalid_steps = self.kernel_size - stride;
+ let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
+ self.state_prev_ys = prev_ys;
+ Ok(ys)
+ }
+}
+
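+// Learnt downsampling by a factor of `stride`: a streamable conv1d with kernel size
+// 2 * stride and replicate padding.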
+#[derive(Debug, Clone)]
+pub struct ConvDownsample1d {
+ conv: StreamableConv1d,
+}
+
+impl ConvDownsample1d {
+ pub fn new(
+ stride: usize,
+ dim: usize,
+ causal: bool,
+ learnt: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ if !learnt {
+ candle::bail!("only learnt=true is supported")
+ }
+ let conv = StreamableConv1d::new(
+ /* in_c */ dim,
+ /* out_c */ dim,
+ /* k_size */ 2 * stride,
+ /* stride */ stride,
+ /* dilation */ 1,
+ /* groups */ 1, // channel_wise = false
+ /* bias */ false,
+ /* causal */ causal,
+ /* norm */ None,
+ /* pad_mode */ PadMode::Replicate,
+ vb,
+ )?;
+ Ok(Self { conv })
+ }
+}
+
+impl Module for ConvDownsample1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.conv)
+ }
+}
+
+impl StreamingModule for ConvDownsample1d {
+ fn reset_state(&mut self) {
+ self.conv.reset_state()
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ self.conv.step(xs)
+ }
+}
+
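+// Learnt upsampling by a factor of `stride`: a streamable depthwise transposed conv1d
+// (groups == dim) with kernel size 2 * stride.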
+#[derive(Debug, Clone)]
+pub struct ConvTrUpsample1d {
+ convtr: StreamableConvTranspose1d,
+}
+
+impl ConvTrUpsample1d {
+ pub fn new(
+ stride: usize,
+ dim: usize,
+ causal: bool,
+ learnt: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ if !learnt {
+ candle::bail!("only learnt=true is supported")
+ }
+ let convtr = StreamableConvTranspose1d::new(
+ dim,
+ dim,
+ /* k_size */ 2 * stride,
+ /* stride */ stride,
+ /* groups */ dim,
+ /* bias */ false,
+ /* causal */ causal,
+ /* norm */ None,
+ vb,
+ )?;
+ Ok(Self { convtr })
+ }
+}
+
+impl Module for ConvTrUpsample1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.convtr)
+ }
+}
+
+impl StreamingModule for ConvTrUpsample1d {
+ fn reset_state(&mut self) {
+ self.convtr.reset_state()
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ self.convtr.step(xs)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use candle::IndexOp;
+
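+ // The tests below check that the streaming implementations agree with the
+ // non-streaming ones: a single forward pass over the full sequence and a
+ // step-by-step pass over fixed-size chunks should match to within 1e-5.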
+ fn run_conv1d(
+ k_size: usize,
+ stride: usize,
+ dilation: usize,
+ step_size: usize,
+ len: usize,
+ bias: bool,
+ ) -> Result<()> {
+ // TODO: We should ensure that the seed is constant when running these tests.
+ let dev = &candle::Device::Cpu;
+ let vm = candle_nn::VarMap::new();
+ let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
+ let conv1d = StreamableConv1d::new(
+ /* in_c */ 2,
+ /* out_c */ 3,
+ /* k_size */ k_size,
+ /* stride */ stride,
+ /* dilation */ dilation,
+ /* groups */ 1,
+ /* bias */ bias,
+ /* causal */ true,
+ /* norm */ None,
+ /* pad_mode */ PadMode::Constant,
+ vb,
+ )?;
+ let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
+ let ys = conv1d.forward(&xs)?;
+ let mut conv1d = conv1d;
+ let mut ys_steps = vec![];
+ for idx in 0..len {
+ let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
+ let ys = conv1d.step(&xs.into())?;
+ if let Some(ys) = ys.as_option() {
+ ys_steps.push(ys.clone())
+ }
+ }
+ let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
+ let diff = (&ys - &ys_steps)?
+ .abs()?
+ .flatten_all()?
+ .max(0)?
+ .to_vec0::<f32>()?;
+ if diff > 1e-5 {
+ println!("{xs}");
+ println!("{ys}");
+ println!("{ys_steps}");
+ candle::bail!("larger diff than expected {diff}")
+ }
+ Ok(())
+ }
+
+ fn run_conv_tr1d(
+ k_size: usize,
+ stride: usize,
+ step_size: usize,
+ len: usize,
+ bias: bool,
+ ) -> Result<()> {
+ // TODO: We should ensure that the seed is constant when running these tests.
+ let dev = &candle::Device::Cpu;
+ let vm = candle_nn::VarMap::new();
+ let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
+ let conv1d = StreamableConvTranspose1d::new(
+ /* in_c */ 2, /* out_c */ 3, /* k_size */ k_size,
+ /* stride */ stride, /* groups */ 1, /* bias */ bias,
+ /* causal */ true, /* norm */ None, vb,
+ )?;
+ let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
+ let ys = conv1d.forward(&xs)?;
+ let mut conv1d = conv1d;
+ let mut ys_steps = vec![];
+ for idx in 0..len {
+ let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
+ let ys = conv1d.step(&xs.into())?;
+ if let Some(ys) = ys.as_option() {
+ ys_steps.push(ys.clone())
+ }
+ }
+ let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
+ let diff = (&ys - &ys_steps)?
+ .abs()?
+ .flatten_all()?
+ .max(0)?
+ .to_vec0::<f32>()?;
+ if diff > 1e-5 {
+ println!("{xs}");
+ println!("{ys}");
+ println!("{ys_steps}");
+ candle::bail!("larger diff than expected {diff}")
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn conv1d() -> Result<()> {
+ for step_size in [1, 2, 3] {
+ for bias in [false, true] {
+ run_conv1d(1, 1, 1, step_size, 5, bias)?;
+ run_conv1d(2, 1, 1, step_size, 5, bias)?;
+ run_conv1d(2, 2, 1, step_size, 6, bias)?;
+ run_conv1d(3, 2, 1, step_size, 8, bias)?;
+ run_conv1d(3, 2, 2, step_size, 8, bias)?;
+ }
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn conv_tr1d() -> Result<()> {
+ for step_size in [1, 2, 3] {
+ for bias in [false, true] {
+ run_conv_tr1d(1, 1, step_size, 5, bias)?;
+ run_conv_tr1d(2, 1, step_size, 5, bias)?;
+ run_conv_tr1d(3, 1, step_size, 5, bias)?;
+ run_conv_tr1d(3, 2, step_size, 5, bias)?;
+ }
+ }
+ Ok(())
+ }
+}