-rw-r--r--  .gitignore                                           |   3
-rw-r--r--  candle-examples/Cargo.toml                           |   5
-rw-r--r--  candle-examples/examples/mimi/README.md              |  20
-rw-r--r--  candle-examples/examples/mimi/audio_io.rs            | 275
-rw-r--r--  candle-examples/examples/mimi/main.rs                | 131
-rw-r--r--  candle-transformers/src/models/mimi/conv.rs          | 670
-rw-r--r--  candle-transformers/src/models/mimi/encodec.rs       | 229
-rw-r--r--  candle-transformers/src/models/mimi/mod.rs           |  22
-rw-r--r--  candle-transformers/src/models/mimi/quantization.rs  | 404
-rw-r--r--  candle-transformers/src/models/mimi/seanet.rs        | 465
-rw-r--r--  candle-transformers/src/models/mimi/transformer.rs   | 802
-rw-r--r--  candle-transformers/src/models/mod.rs                |   1
12 files changed, 3027 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
index 38a7d504..4dfbcc16 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,3 +43,6 @@ candle-wasm-examples/**/config*.json
__pycache__
out.safetensors
out.wav
+bria.mp3
+bria.safetensors
+bria.wav
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 6879c48b..543c9666 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -67,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
+mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
@@ -102,6 +103,10 @@ name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
+name = "mimi"
+required-features = ["mimi"]
+
+[[example]]
name = "encodec"
required-features = ["encodec"]
diff --git a/candle-examples/examples/mimi/README.md b/candle-examples/examples/mimi/README.md
new file mode 100644
index 00000000..bbcfcdb7
--- /dev/null
+++ b/candle-examples/examples/mimi/README.md
@@ -0,0 +1,20 @@
+# candle-mimi
+
+[Mimi](https://huggingface.co/kyutai/mimi) is a state-of-the-art audio
+compression model using an encoder/decoder architecture with residual vector
+quantization. The candle implementation supports streaming, meaning that it is
+possible to encode or decode a stream of audio tokens on the fly, providing
+low-latency interaction with an audio model.
+
+## Running the example
+
+Generate audio tokens from an audio file:
+```bash
+wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
+cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
+```
+
+And decode the audio tokens back into a sound file:
+```bash
+cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
+```
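+
+The example can also stream audio directly: using `-` as the input records
+from the default microphone until enter is pressed, and using `-` as the
+output plays the decoded audio on the default output device. Assuming a
+working cpal audio setup, the following round-trips the microphone through
+the codec:
+```bash
+cargo run --example mimi --features mimi --release -- audio-to-audio - -
+```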
diff --git a/candle-examples/examples/mimi/audio_io.rs b/candle-examples/examples/mimi/audio_io.rs
new file mode 100644
index 00000000..2103dd4a
--- /dev/null
+++ b/candle-examples/examples/mimi/audio_io.rs
@@ -0,0 +1,275 @@
+#![allow(unused)]
+use anyhow::{Context, Result};
+use std::sync::{Arc, Mutex};
+
+pub const SAMPLE_RATE: usize = 24_000;
+
+pub(crate) struct AudioOutputData_ {
+ resampled_data: std::collections::VecDeque<f32>,
+ resampler: rubato::FastFixedIn<f32>,
+ output_buffer: Vec<f32>,
+ input_buffer: Vec<f32>,
+ input_len: usize,
+}
+
+impl AudioOutputData_ {
+ pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
+ use rubato::Resampler;
+
+ let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
+ let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
+ let resampler = rubato::FastFixedIn::new(
+ resample_ratio,
+ f64::max(resample_ratio, 1.0),
+ rubato::PolynomialDegree::Septic,
+ 1024,
+ 1,
+ )?;
+ let input_buffer = resampler.input_buffer_allocate(true).remove(0);
+ let output_buffer = resampler.output_buffer_allocate(true).remove(0);
+ Ok(Self {
+ resampled_data,
+ resampler,
+ input_buffer,
+ output_buffer,
+ input_len: 0,
+ })
+ }
+
+ pub fn reset(&mut self) {
+ use rubato::Resampler;
+ self.output_buffer.fill(0.);
+ self.input_buffer.fill(0.);
+ self.resampler.reset();
+ self.resampled_data.clear();
+ }
+
+ pub(crate) fn take_all(&mut self) -> Vec<f32> {
+ let mut data = Vec::with_capacity(self.resampled_data.len());
+ while let Some(elem) = self.resampled_data.pop_back() {
+ data.push(elem);
+ }
+ data
+ }
+
+ pub(crate) fn is_empty(&self) -> bool {
+ self.resampled_data.is_empty()
+ }
+
+ // Assumes that the input buffer is large enough.
+ fn push_input_buffer(&mut self, samples: &[f32]) {
+ self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
+ self.input_len += samples.len()
+ }
+
+ pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
+ use rubato::Resampler;
+
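+ // rubato's FastFixedIn consumes fixed-size input chunks (1024 frames here),
+ // so incoming samples are buffered until a full chunk is available.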
+ let mut pos_in = 0;
+ loop {
+ let rem = self.input_buffer.len() - self.input_len;
+ let pos_end = usize::min(pos_in + rem, samples.len());
+ self.push_input_buffer(&samples[pos_in..pos_end]);
+ pos_in = pos_end;
+ if self.input_len < self.input_buffer.len() {
+ break;
+ }
+ let (_, out_len) = self.resampler.process_into_buffer(
+ &[&self.input_buffer],
+ &mut [&mut self.output_buffer],
+ None,
+ )?;
+ for &elem in self.output_buffer[..out_len].iter() {
+ self.resampled_data.push_front(elem)
+ }
+ self.input_len = 0;
+ }
+ Ok(())
+ }
+}
+
+type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
+
+pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
+ use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
+
+ println!("Setup audio output stream!");
+ let host = cpal::default_host();
+ let device = host
+ .default_output_device()
+ .context("no output device available")?;
+ let mut supported_configs_range = device.supported_output_configs()?;
+ let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
+ // On macOS, it's commonly the case that there are only stereo outputs.
+ None => device
+ .supported_output_configs()?
+ .next()
+ .context("no audio output available")?,
+ Some(config_range) => config_range,
+ };
+ let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
+ config_range.min_sample_rate(),
+ config_range.max_sample_rate(),
+ );
+ let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
+ let channels = config.channels as usize;
+ println!(
+ "cpal device: {} {} {config:?}",
+ device.name().unwrap_or_else(|_| "unk".to_string()),
+ config.sample_rate.0
+ );
+ let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
+ SAMPLE_RATE,
+ config.sample_rate.0 as usize,
+ )?));
+ let ad = audio_data.clone();
+ let stream = device.build_output_stream(
+ &config,
+ move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
+ data.fill(0.);
+ let mut ad = ad.lock().unwrap();
+ let mut last_elem = 0f32;
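+ // The generated audio is mono; write each resampled value on the first
+ // channel and duplicate it on the remaining ones (e.g. for stereo outputs).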
+ for (idx, elem) in data.iter_mut().enumerate() {
+ if idx % channels == 0 {
+ match ad.resampled_data.pop_back() {
+ None => break,
+ Some(v) => {
+ last_elem = v;
+ *elem = v
+ }
+ }
+ } else {
+ *elem = last_elem
+ }
+ }
+ },
+ move |err| eprintln!("cpal error: {err}"),
+ None, // None=blocking, Some(Duration)=timeout
+ )?;
+ stream.play()?;
+ Ok((stream, audio_data))
+}
+
+pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
+ use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
+
+ println!("Setup audio input stream!");
+ let host = cpal::default_host();
+ let device = host
+ .default_input_device()
+ .context("no input device available")?;
+ let mut supported_configs_range = device.supported_input_configs()?;
+ let config_range = supported_configs_range
+ .find(|c| c.channels() == 1)
+ .context("no audio input available")?;
+ let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
+ config_range.min_sample_rate(),
+ config_range.max_sample_rate(),
+ );
+ let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
+ println!(
+ "cpal device: {} {} {config:?}",
+ device.name().unwrap_or_else(|_| "unk".to_string()),
+ config.sample_rate.0
+ );
+ let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
+ config.sample_rate.0 as usize,
+ SAMPLE_RATE,
+ )?));
+ let ad = audio_data.clone();
+ let stream = device.build_input_stream(
+ &config,
+ move |data: &[f32], _: &cpal::InputCallbackInfo| {
+ let mut ad = ad.lock().unwrap();
+ if let Err(err) = ad.push_samples(data) {
+ eprintln!("error processing audio input {err:?}")
+ }
+ },
+ move |err| eprintln!("cpal error: {err}"),
+ None, // None=blocking, Some(Duration)=timeout
+ )?;
+ stream.play()?;
+ Ok((stream, audio_data))
+}
+
+fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
+where
+ T: symphonia::core::sample::Sample,
+ f32: symphonia::core::conv::FromSample<T>,
+{
+ use symphonia::core::audio::Signal;
+ use symphonia::core::conv::FromSample;
+ samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
+}
+
+pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
+ use symphonia::core::audio::{AudioBufferRef, Signal};
+
+ let src = std::fs::File::open(path)?;
+ let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
+ let hint = symphonia::core::probe::Hint::new();
+ let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
+ let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
+ let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
+ let mut format = probed.format;
+ let track = format
+ .tracks()
+ .iter()
+ .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
+ .expect("no supported audio tracks");
+ let mut decoder = symphonia::default::get_codecs()
+ .make(&track.codec_params, &Default::default())
+ .expect("unsupported codec");
+ let track_id = track.id;
+ let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
+ let mut pcm_data = Vec::new();
+ while let Ok(packet) = format.next_packet() {
+ while !format.metadata().is_latest() {
+ format.metadata().pop();
+ }
+ if packet.track_id() != track_id {
+ continue;
+ }
+ match decoder.decode(&packet)? {
+ AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
+ AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
+ AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
+ AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
+ AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
+ AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
+ AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
+ AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
+ AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
+ AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
+ }
+ }
+ Ok((pcm_data, sample_rate))
+}
+
+pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
+ use rubato::Resampler;
+
+ let mut pcm_out =
+ Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
+
+ let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
+ let mut output_buffer = resampler.output_buffer_allocate(true);
+ let mut pos_in = 0;
+ while pos_in + resampler.input_frames_next() < pcm_in.len() {
+ let (in_len, out_len) =
+ resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
+ pos_in += in_len;
+ pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
+ }
+
+ if pos_in < pcm_in.len() {
+ let (_in_len, out_len) = resampler.process_partial_into_buffer(
+ Some(&[&pcm_in[pos_in..]]),
+ &mut output_buffer,
+ None,
+ )?;
+ pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
+ }
+
+ Ok(pcm_out)
+}
diff --git a/candle-examples/examples/mimi/main.rs b/candle-examples/examples/mimi/main.rs
new file mode 100644
index 00000000..cfc1a553
--- /dev/null
+++ b/candle-examples/examples/mimi/main.rs
@@ -0,0 +1,131 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use candle::{DType, IndexOp, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::mimi::{Config, Model};
+use clap::{Parser, ValueEnum};
+use hf_hub::api::sync::Api;
+
+mod audio_io;
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Action {
+ AudioToAudio,
+ AudioToCode,
+ CodeToAudio,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// The action to be performed, specifies the format for the input and output data.
+ action: Action,
+
+ /// The input file, either an audio file or some mimi tokens stored as safetensors.
+ in_file: String,
+
+ /// The output file, either a wave audio file or some mimi tokens stored as safetensors.
+ out_file: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// The model weight file, in safetensor format.
+ #[arg(long)]
+ model: Option<String>,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ let device = candle_examples::device(args.cpu)?;
+ let model = match args.model {
+ Some(model) => std::path::PathBuf::from(model),
+ None => Api::new()?
+ .model("kyutai/mimi".to_string())
+ .get("model.safetensors")?,
+ };
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
+ let config = Config::v0_1(None);
+ let mut model = Model::new(config, vb)?;
+
+ let codes = match args.action {
+ Action::CodeToAudio => {
+ let codes = candle::safetensors::load(args.in_file, &device)?;
+ codes.get("codes").expect("no codes in input file").clone()
+ }
+ Action::AudioToCode | Action::AudioToAudio => {
+ let pcm = if args.in_file == "-" {
+ println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
+ let (stream, input_audio) = audio_io::setup_input_stream()?;
+ let mut pcms = vec![];
+ let stdin = std::thread::spawn(|| {
+ let mut s = String::new();
+ std::io::stdin().read_line(&mut s)
+ });
+ while !stdin.is_finished() {
+ let input = input_audio.lock().unwrap().take_all();
+ if input.is_empty() {
+ std::thread::sleep(std::time::Duration::from_millis(100));
+ continue;
+ }
+ pcms.push(input)
+ }
+ drop(stream);
+ pcms.concat()
+ } else {
+ let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
+ if sample_rate != 24_000 {
+ println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
+ audio_io::resample(&pcm, sample_rate as usize, 24_000)?
+ } else {
+ pcm
+ }
+ };
+ let pcm_len = pcm.len();
+ let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
+ println!("input pcm shape: {:?}", pcm.shape());
+ model.encode(&pcm)?
+ }
+ };
+ println!("codes shape: {:?}", codes.shape());
+
+ match args.action {
+ Action::AudioToCode => {
+ codes.save_safetensors("codes", &args.out_file)?;
+ }
+ Action::AudioToAudio | Action::CodeToAudio => {
+ let pcm = model.decode(&codes)?;
+ println!("output pcm shape: {:?}", pcm.shape());
+ let pcm = pcm.i(0)?.i(0)?;
+ let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
+ let pcm = pcm.to_vec1::<f32>()?;
+ if args.out_file == "-" {
+ let (stream, ad) = audio_io::setup_output_stream()?;
+ {
+ let mut ad = ad.lock().unwrap();
+ ad.push_samples(&pcm)?;
+ }
+ loop {
+ let ad = ad.lock().unwrap();
+ if ad.is_empty() {
+ break;
+ }
+                    // Oddly, calling thread::sleep here causes the stream to stop playing
+                    // (the callback no longer seems to be invoked).
+ // std::thread::sleep(std::time::Duration::from_millis(100));
+ }
+ drop(stream)
+ } else {
+ let mut output = std::fs::File::create(&args.out_file)?;
+ candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
+ }
+ }
+ }
+ Ok(())
+}
diff --git a/candle-transformers/src/models/mimi/conv.rs b/candle-transformers/src/models/mimi/conv.rs
new file mode 100644
index 00000000..87e9fb4c
--- /dev/null
+++ b/candle-transformers/src/models/mimi/conv.rs
@@ -0,0 +1,670 @@
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D};
+use candle_nn::{Conv1d, VarBuilder};
+
+#[allow(clippy::enum_variant_names)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum Norm {
+ WeightNorm,
+ SpectralNorm,
+ TimeGroupNorm,
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum PadMode {
+ Constant,
+ Reflect,
+ Replicate,
+}
+
+// Applies weight norm for inference by recomputing the weight tensor. This
+// does not apply to training.
+// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
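+// The reparametrization is w = g * v / ||v||, with the norm of v taken over
+// the (in_c, k_size) dims of each output channel (sum_keepdim over dims 1 and 2).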
+fn conv1d_weight_norm(
+ in_c: usize,
+ out_c: usize,
+ kernel_size: usize,
+ bias: bool,
+ config: candle_nn::Conv1dConfig,
+ vb: VarBuilder,
+) -> Result<Conv1d> {
+ let weight = if vb.contains_tensor("weight") {
+ vb.get((out_c, in_c, kernel_size), "weight")?
+ } else {
+ let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
+ let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
+ let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
+ weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
+ };
+ let bias = if bias {
+ Some(vb.get(out_c, "bias")?)
+ } else {
+ None
+ };
+ Ok(Conv1d::new(weight, bias, config))
+}
+
+#[derive(Debug, Clone)]
+pub struct NormConv1d {
+ conv: Conv1d,
+ norm: Option<candle_nn::GroupNorm>,
+ span: tracing::Span,
+}
+
+impl NormConv1d {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ in_c: usize,
+ out_c: usize,
+ k_size: usize,
+ causal: bool,
+ norm: Option<Norm>,
+ bias: bool,
+ cfg: candle_nn::Conv1dConfig,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let conv = match norm {
+ None | Some(Norm::TimeGroupNorm) => {
+ if bias {
+ candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
+ } else {
+ candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
+ }
+ }
+ Some(Norm::WeightNorm) => {
+ conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
+ }
+ Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
+ };
+ let norm = match norm {
+ None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
+ Some(Norm::TimeGroupNorm) => {
+ if causal {
+ candle::bail!("GroupNorm doesn't support causal evaluation.")
+ }
+ let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
+ Some(norm)
+ }
+ };
+ Ok(Self {
+ conv,
+ norm,
+ span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
+ })
+ }
+}
+
+impl Module for NormConv1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = xs.apply(&self.conv)?;
+ match self.norm.as_ref() {
+ None => Ok(xs),
+ Some(norm) => xs.apply(norm),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct NormConvTranspose1d {
+ ws: Tensor,
+ bs: Option<Tensor>,
+ k_size: usize,
+ stride: usize,
+ groups: usize,
+ norm: Option<candle_nn::GroupNorm>,
+ span: tracing::Span,
+}
+
+impl NormConvTranspose1d {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ in_c: usize,
+ out_c: usize,
+ k_size: usize,
+ causal: bool,
+ norm: Option<Norm>,
+ bias: bool,
+ stride: usize,
+ groups: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb = vb.pp("conv");
+ let bs = if bias {
+ Some(vb.get(out_c, "bias")?)
+ } else {
+ None
+ };
+ let ws = match norm {
+ None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
+ Some(Norm::WeightNorm) => {
+ if vb.contains_tensor("weight") {
+ vb.get((in_c, out_c, k_size), "weight")?
+ } else {
+ let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
+ let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
+ let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
+ weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
+ }
+ }
+ Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
+ };
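+        // Depthwise transposed convolutions (groups == in_c == out_c) are expanded
+        // into an equivalent groups=1 weight by masking a repeated kernel with the
+        // identity, avoiding the need for a grouped conv-transpose kernel.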
+ let (ws, groups) = if groups == out_c && in_c == out_c {
+ let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
+ let ws = ws
+ .repeat((1, out_c, 1))?
+ .mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
+ (ws, 1)
+ } else {
+ (ws, groups)
+ };
+ let norm = match norm {
+ None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
+ Some(Norm::TimeGroupNorm) => {
+ if causal {
+ candle::bail!("GroupNorm doesn't support causal evaluation.")
+ }
+ let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
+ Some(norm)
+ }
+ };
+ Ok(Self {
+ ws,
+ bs,
+ k_size,
+ stride,
+ groups,
+ norm,
+ span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
+ })
+ }
+}
+
+impl Module for NormConvTranspose1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+        // conv-transpose1d used to be broken on metal after enough iterations,
+        // causing the following error:
+        // _status < MTLCommandBufferStatusCommitted >
+        // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
+        // This is now fixed in candle.
+ let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
+ let xs = match &self.bs {
+ None => xs,
+ Some(bias) => {
+ let b = bias.dims1()?;
+ let bias = bias.reshape((1, b, 1))?;
+ xs.broadcast_add(&bias)?
+ }
+ };
+ match self.norm.as_ref() {
+ None => Ok(xs),
+ Some(norm) => xs.apply(norm),
+ }
+ }
+}
+
+fn get_extra_padding_for_conv1d(
+ xs: &Tensor,
+ k_size: usize,
+ stride: usize,
+ padding_total: usize,
+) -> Result<usize> {
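+    // With n_frames = (len + padding_total - k_size) / stride + 1, compute the
+    // smallest length >= len that yields a whole number of frames and return
+    // the difference as extra right padding.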
+ let len = xs.dim(D::Minus1)?;
+ let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
+ let ideal_len =
+ ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
+ Ok(ideal_len.saturating_sub(len))
+}
+
+fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
+ match mode {
+ PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
+ PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
+ PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
+ }
+}
+
+fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
+ let len = xs.dim(D::Minus1)?;
+ if len < unpad_l + unpad_r {
+ candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
+ }
+ xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
+}
+
+#[derive(Debug, Clone)]
+pub struct StreamableConv1d {
+ conv: NormConv1d,
+ causal: bool,
+ pad_mode: PadMode,
+ state_prev_xs: StreamTensor,
+ left_pad_applied: bool,
+ kernel_size: usize,
+ span: tracing::Span,
+}
+
+impl StreamableConv1d {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ in_c: usize,
+ out_c: usize,
+ k_size: usize,
+ stride: usize,
+ dilation: usize,
+ groups: usize,
+ bias: bool,
+ causal: bool,
+ norm: Option<Norm>,
+ pad_mode: PadMode,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let cfg = candle_nn::Conv1dConfig {
+ padding: 0,
+ stride,
+ dilation,
+ groups,
+ };
+ let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
+ if k_size < stride {
+ candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
+ }
+ Ok(Self {
+ conv,
+ causal,
+ pad_mode,
+ state_prev_xs: StreamTensor::empty(),
+ left_pad_applied: false,
+ kernel_size: k_size,
+ span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
+ })
+ }
+}
+
+impl Module for StreamableConv1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (_b, _t, _c) = xs.dims3()?;
+ let k_size = self.conv.conv.weight().dim(D::Minus1)?;
+ let conv_cfg = self.conv.conv.config();
+ // Effective kernel size with dilations.
+ let k_size = (k_size - 1) * conv_cfg.dilation + 1;
+ let padding_total = k_size - conv_cfg.stride;
+ let extra_padding =
+ get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
+ let xs = if self.causal {
+ pad1d(xs, padding_total, extra_padding, self.pad_mode)?
+ } else {
+ let padding_right = padding_total / 2;
+ let padding_left = padding_total - padding_right;
+ pad1d(
+ xs,
+ padding_left,
+ padding_right + extra_padding,
+ self.pad_mode,
+ )?
+ };
+ xs.apply(&self.conv)
+ }
+}
+
+impl StreamingModule for StreamableConv1d {
+ fn reset_state(&mut self) {
+ self.state_prev_xs.reset();
+ self.left_pad_applied = false;
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let _enter = self.span.enter();
+ let xs = match xs.as_option() {
+ None => return Ok(().into()),
+ Some(xs) => xs.clone(),
+ };
+ let xs = if self.left_pad_applied {
+ xs
+ } else {
+ self.left_pad_applied = true;
+ let k_size = self.conv.conv.weight().dim(D::Minus1)?;
+ let conv_cfg = self.conv.conv.config();
+ let k_size = (k_size - 1) * conv_cfg.dilation + 1;
+ let padding_total = k_size - conv_cfg.stride;
+ pad1d(&xs, padding_total, 0, self.pad_mode)?
+ };
+ let cfg = self.conv.conv.config();
+ let stride = cfg.stride;
+ let dilation = cfg.dilation;
+ let kernel = (self.kernel_size - 1) * dilation + 1;
+ let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
+ let seq_len = xs.seq_len(D::Minus1)?;
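+        // The first frame requires `kernel` samples and each additional frame
+        // `stride` more, hence the number of full frames computed below.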
+ let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
+ if num_frames > 0 {
+ let offset = num_frames * stride;
+ self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
+ let in_l = (num_frames - 1) * stride + kernel;
+ let xs = xs.narrow(D::Minus1, 0, in_l)?;
+            // We apply the underlying conv directly rather than through Self::forward
+            // so as not to apply any padding here.
+ xs.apply(&self.conv.conv)
+ } else {
+ self.state_prev_xs = xs;
+ Ok(StreamTensor::empty())
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct StreamableConvTranspose1d {
+ convtr: NormConvTranspose1d,
+ causal: bool,
+ state_prev_ys: StreamTensor,
+ kernel_size: usize,
+ span: tracing::Span,
+}
+
+impl StreamableConvTranspose1d {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ in_c: usize,
+ out_c: usize,
+ k_size: usize,
+ stride: usize,
+ groups: usize,
+ bias: bool,
+ causal: bool,
+ norm: Option<Norm>,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let convtr =
+ NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?;
+ Ok(Self {
+ convtr,
+ causal,
+ kernel_size: k_size,
+ state_prev_ys: StreamTensor::empty(),
+ span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
+ })
+ }
+}
+
+impl Module for StreamableConvTranspose1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let k_size = self.convtr.k_size;
+ let stride = self.convtr.stride;
+ let padding_total = k_size.saturating_sub(stride);
+ let xs = xs.apply(&self.convtr)?;
+ if self.causal {
+ // This corresponds to trim_right_ratio = 1.
+ unpad1d(&xs, 0, padding_total)
+ } else {
+ let padding_right = padding_total / 2;
+ let padding_left = padding_total - padding_right;
+ unpad1d(&xs, padding_left, padding_right)
+ }
+ }
+}
+
+impl StreamingModule for StreamableConvTranspose1d {
+ fn reset_state(&mut self) {
+ self.state_prev_ys.reset()
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let _enter = self.span.enter();
+ let xs = match xs.as_option() {
+ Some(xs) => xs,
+ None => return Ok(StreamTensor::empty()),
+ };
+ let stride = self.convtr.stride;
+        // We apply the underlying convtr directly rather than through Self::forward
+        // so as not to apply any unpadding here.
+ let ys = self.convtr.forward(xs)?;
+ let ot = ys.dim(D::Minus1)?;
+ let ys = match self.state_prev_ys.as_option() {
+ None => ys,
+ Some(prev_ys) => {
+ let pt = prev_ys.dim(D::Minus1)?;
+ // Remove the bias as it will be applied multiple times.
+ let prev_ys = match &self.convtr.bs {
+ None => prev_ys.clone(),
+ Some(bias) => {
+ let bias = bias.reshape((1, (), 1))?;
+ prev_ys.broadcast_sub(&bias)?
+ }
+ };
+ let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
+ let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
+ Tensor::cat(&[ys1, ys2], D::Minus1)?
+ }
+ };
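+        // The last k_size - stride outputs are incomplete: they also depend on
+        // future inputs, so they are kept in the state and emitted once complete.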
+ let invalid_steps = self.kernel_size - stride;
+ let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
+ self.state_prev_ys = prev_ys;
+ Ok(ys)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct ConvDownsample1d {
+ conv: StreamableConv1d,
+}
+
+impl ConvDownsample1d {
+ pub fn new(
+ stride: usize,
+ dim: usize,
+ causal: bool,
+ learnt: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ if !learnt {
+ candle::bail!("only learnt=true is supported")
+ }
+ let conv = StreamableConv1d::new(
+ /* in_c */ dim,
+ /* out_c */ dim,
+ /* k_size_c */ 2 * stride,
+ /* stride */ stride,
+ /* dilation */ 1,
+ /* groups */ 1, // channel_wise = false
+ /* bias */ false,
+ /* causal */ causal,
+ /* norm */ None,
+ /* pad_mode */ PadMode::Replicate,
+ vb,
+ )?;
+ Ok(Self { conv })
+ }
+}
+
+impl Module for ConvDownsample1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.conv)
+ }
+}
+
+impl StreamingModule for ConvDownsample1d {
+ fn reset_state(&mut self) {
+ self.conv.reset_state()
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ self.conv.step(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct ConvTrUpsample1d {
+ convtr: StreamableConvTranspose1d,
+}
+
+impl ConvTrUpsample1d {
+ pub fn new(
+ stride: usize,
+ dim: usize,
+ causal: bool,
+ learnt: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ if !learnt {
+ candle::bail!("only learnt=true is supported")
+ }
+ let convtr = StreamableConvTranspose1d::new(
+ dim,
+ dim,
+ /* k_size */ 2 * stride,
+ /* stride */ stride,
+ /* groups */ dim,
+ /* bias */ false,
+ /* causal */ causal,
+ /* norm */ None,
+ vb,
+ )?;
+ Ok(Self { convtr })
+ }
+}
+
+impl Module for ConvTrUpsample1d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.convtr)
+ }
+}
+
+impl StreamingModule for ConvTrUpsample1d {
+ fn reset_state(&mut self) {
+ self.convtr.reset_state()
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ self.convtr.step(xs)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use candle::IndexOp;
+
+ fn run_conv1d(
+ k_size: usize,
+ stride: usize,
+ dilation: usize,
+ step_size: usize,
+ len: usize,
+ bias: bool,
+ ) -> Result<()> {
+        // TODO: We should use a fixed seed when running these tests.
+ let dev = &candle::Device::Cpu;
+ let vm = candle_nn::VarMap::new();
+ let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
+ let conv1d = StreamableConv1d::new(
+ /* in_c */ 2,
+ /* out_c */ 3,
+ /* k_size */ k_size,
+ /* stride */ stride,
+ /* dilation */ dilation,
+ /* groups */ 1,
+ /* bias */ bias,
+ /* causal */ true,
+ /* norm */ None,
+ /* pad_mode */ PadMode::Constant,
+ vb,
+ )?;
+ let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
+ let ys = conv1d.forward(&xs)?;
+ let mut conv1d = conv1d;
+ let mut ys_steps = vec![];
+ for idx in 0..len {
+ let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
+ let ys = conv1d.step(&xs.into())?;
+ if let Some(ys) = ys.as_option() {
+ ys_steps.push(ys.clone())
+ }
+ }
+ let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
+ let diff = (&ys - &ys_steps)?
+ .abs()?
+ .flatten_all()?
+ .max(0)?
+ .to_vec0::<f32>()?;
+ if diff > 1e-5 {
+ println!("{xs}");
+ println!("{ys}");
+ println!("{ys_steps}");
+ candle::bail!("larger diff than expected {diff}")
+ }
+ Ok(())
+ }
+
+ fn run_conv_tr1d(
+ k_size: usize,
+ stride: usize,
+ step_size: usize,
+ len: usize,
+ bias: bool,
+ ) -> Result<()> {
+        // TODO: We should use a fixed seed when running these tests.
+ let dev = &candle::Device::Cpu;
+ let vm = candle_nn::VarMap::new();
+ let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
+ let conv1d = StreamableConvTranspose1d::new(
+ /* in_c */ 2, /* out_c */ 3, /* k_size */ k_size,
+ /* stride */ stride, /* groups */ 1, /* bias */ bias,
+ /* causal */ true, /* norm */ None, vb,
+ )?;
+ let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
+ let ys = conv1d.forward(&xs)?;
+ let mut conv1d = conv1d;
+ let mut ys_steps = vec![];
+ for idx in 0..len {
+ let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
+ let ys = conv1d.step(&xs.into())?;
+ if let Some(ys) = ys.as_option() {
+ ys_steps.push(ys.clone())
+ }
+ }
+ let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
+ let diff = (&ys - &ys_steps)?
+ .abs()?
+ .flatten_all()?
+ .max(0)?
+ .to_vec0::<f32>()?;
+ if diff > 1e-5 {
+ println!("{xs}");
+ println!("{ys}");
+ println!("{ys_steps}");
+ candle::bail!("larger diff than expected {diff}")
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn conv1d() -> Result<()> {
+ for step_size in [1, 2, 3] {
+ for bias in [false, true] {
+ run_conv1d(1, 1, 1, step_size, 5, bias)?;
+ run_conv1d(2, 1, 1, step_size, 5, bias)?;
+ run_conv1d(2, 2, 1, step_size, 6, bias)?;
+ run_conv1d(3, 2, 1, step_size, 8, bias)?;
+ run_conv1d(3, 2, 2, step_size, 8, bias)?;
+ }
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn conv_tr1d() -> Result<()> {
+ for step_size in [1, 2, 3] {
+ for bias in [false, true] {
+ run_conv_tr1d(1, 1, step_size, 5, bias)?;
+ run_conv_tr1d(2, 1, step_size, 5, bias)?;
+ run_conv_tr1d(3, 1, step_size, 5, bias)?;
+ run_conv_tr1d(3, 2, step_size, 5, bias)?;
+ }
+ }
+ Ok(())
+ }
+}
diff --git a/candle-transformers/src/models/mimi/encodec.rs b/candle-transformers/src/models/mimi/encodec.rs
new file mode 100644
index 00000000..f659da3a
--- /dev/null
+++ b/candle-transformers/src/models/mimi/encodec.rs
@@ -0,0 +1,229 @@
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use super::{conv, quantization, seanet, transformer};
+use candle::{DType, Device, Module, Result, StreamTensor, StreamingModule, Tensor};
+use candle_nn::VarBuilder;
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum ResampleMethod {
+ Conv,
+ Interpolate,
+}
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub channels: usize,
+ pub sample_rate: f64,
+ pub frame_rate: f64,
+ pub renormalize: bool,
+ pub resample_method: ResampleMethod,
+ pub seanet: seanet::Config,
+ pub transformer: transformer::Config,
+ pub quantizer_n_q: usize,
+ pub quantizer_bins: usize,
+ pub quantizer_dim: usize,
+}
+
+impl Config {
+ // /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml
+ pub fn v0_1(num_codebooks: Option<usize>) -> Self {
+ let seanet_cfg = seanet::Config {
+ dimension: 512,
+ channels: 1,
+ causal: true,
+ n_filters: 64,
+ n_residual_layers: 1,
+ activation: candle_nn::Activation::Elu(1.),
+ compress: 2,
+ dilation_base: 2,
+ disable_norm_outer_blocks: 0,
+ final_activation: None,
+ kernel_size: 7,
+ residual_kernel_size: 3,
+ last_kernel_size: 3,
+ lstm: 0,
+ norm: conv::Norm::WeightNorm,
+ pad_mode: conv::PadMode::Constant,
+ ratios: vec![8, 6, 5, 4],
+ true_skip: true,
+ };
+ let transformer_cfg = transformer::Config {
+ d_model: seanet_cfg.dimension,
+ num_heads: 8,
+ num_layers: 8,
+ causal: true,
+ norm_first: true,
+ bias_ff: false,
+ bias_attn: false,
+ layer_scale: Some(0.01),
+ context: 250,
+ conv_kernel_size: 5,
+ use_conv_bias: true,
+ use_conv_block: false,
+ cross_attention: false,
+ max_period: 10000,
+ gating: None,
+ norm: super::NormType::LayerNorm,
+ positional_embedding: transformer::PositionalEmbedding::Rope,
+
+ dim_feedforward: 2048,
+ kv_repeat: 1,
+ conv_layout: true, // see builders.py
+            max_seq_len: 8192, // the transformer works at 25 Hz so this is ~5 minutes.
+ };
+ Config {
+ channels: 1,
+ sample_rate: 24_000.,
+ frame_rate: 12.5,
+ renormalize: true,
+ resample_method: ResampleMethod::Conv,
+ seanet: seanet_cfg,
+ transformer: transformer_cfg,
+ quantizer_n_q: num_codebooks.unwrap_or(16),
+ quantizer_bins: 2048,
+ quantizer_dim: 256,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Encodec {
+ encoder: seanet::SeaNetEncoder,
+ decoder: seanet::SeaNetDecoder,
+ encoder_transformer: transformer::ProjectedTransformer,
+ decoder_transformer: transformer::ProjectedTransformer,
+ downsample: conv::ConvDownsample1d,
+ upsample: conv::ConvTrUpsample1d,
+ quantizer: quantization::SplitResidualVectorQuantizer,
+ config: Config,
+}
+
+impl Encodec {
+ pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
+ let dim = cfg.seanet.dimension;
+ let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
+ let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
+ let encoder_transformer = transformer::ProjectedTransformer::new(
+ dim,
+ &[dim],
+ &cfg.transformer,
+ vb.pp("encoder_transformer"),
+ )?;
+ let decoder_transformer = transformer::ProjectedTransformer::new(
+ dim,
+ &[dim],
+ &cfg.transformer,
+ vb.pp("decoder_transformer"),
+ )?;
+ let quantizer = quantization::SplitResidualVectorQuantizer::new(
+ /* dim */ cfg.quantizer_dim,
+ /* input_dim */ Some(dim),
+ /* output_dim */ Some(dim),
+ /* n_q */ cfg.quantizer_n_q,
+ /* bins */ cfg.quantizer_bins,
+ vb.pp("quantizer"),
+ )?;
+ let encoder_frame_rate =
+ cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
+
+ let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
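+        // With the v0.1 config: 24_000 / (8 * 6 * 5 * 4) = 25 Hz at the seanet
+        // output, downsampled with a stride of 2 to reach the 12.5 Hz frame rate.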
+ // `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate.
+ let downsample = conv::ConvDownsample1d::new(
+ /* stride */ downsample_stride,
+ /* dim */ dim,
+ /* causal */ true,
+ /* learnt */ true,
+ vb.pp("downsample"),
+ )?;
+ let upsample = conv::ConvTrUpsample1d::new(
+ /* stride */ downsample_stride,
+ /* dim */ dim,
+ /* causal */ true,
+ /* learnt */ true,
+ vb.pp("upsample"),
+ )?;
+
+ Ok(Self {
+ encoder,
+ decoder,
+ encoder_transformer,
+ decoder_transformer,
+ quantizer,
+ downsample,
+ upsample,
+ config: cfg,
+ })
+ }
+
+ pub fn config(&self) -> &Config {
+ &self.config
+ }
+
+ pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.encoder.forward(xs)?;
+ self.encoder_transformer.reset_state();
+ let xs = self.encoder_transformer.forward(&xs)?;
+ let xs = &xs[0];
+ xs.apply(&self.downsample)
+ }
+
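+    /// Encodes 24kHz pcm data with shape [batch, channels, time] into discrete
+    /// codes with shape [batch, n_q, frames] at 12.5 frames per second.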
+ pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.encoder.forward(xs)?;
+ self.encoder_transformer.reset_state();
+ let xs = self.encoder_transformer.forward(&xs)?;
+ let xs = &xs[0];
+ let xs = xs.apply(&self.downsample)?;
+ let codes = self.quantizer.encode(&xs)?;
+ Ok(codes)
+ }
+
+ pub fn encode_step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let xs = self.encoder.step(xs)?;
+ let xs = self.encoder_transformer.step(&xs)?;
+ let xs = self.downsample.step(&xs)?;
+ match xs.as_option() {
+ None => Ok(().into()),
+ Some(xs) => {
+ let codes = self.quantizer.encode(xs)?;
+ Ok(codes.into())
+ }
+ }
+ }
+
+ pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
+ let emb = self.quantizer.decode(codes)?;
+ let emb = emb.apply(&self.upsample)?;
+ self.decoder_transformer.reset_state();
+ let outs = self.decoder_transformer.forward(&emb)?;
+ let out = &outs[0];
+ self.decoder.forward(out)
+ }
+
+ pub fn decode_step(&mut self, codes: &StreamTensor) -> Result<StreamTensor> {
+ let emb = match codes.as_option() {
+ Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
+ None => StreamTensor::empty(),
+ };
+ let emb = self.upsample.step(&emb)?;
+ let out = self.decoder_transformer.step(&emb)?;
+ self.decoder.step(&out)
+ }
+
+ pub fn reset_state(&mut self) {
+ self.encoder.reset_state();
+ self.encoder_transformer.reset_state();
+ self.decoder.reset_state();
+ self.decoder_transformer.reset_state();
+ self.upsample.reset_state();
+ }
+}
+
+pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Encodec> {
+ let vb =
+ unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
+ let cfg = Config::v0_1(num_codebooks);
+ let encodec = Encodec::new(cfg, vb)?;
+ Ok(encodec)
+}
diff --git a/candle-transformers/src/models/mimi/mod.rs b/candle-transformers/src/models/mimi/mod.rs
new file mode 100644
index 00000000..dc40e38e
--- /dev/null
+++ b/candle-transformers/src/models/mimi/mod.rs
@@ -0,0 +1,22 @@
+// Adapted from the reference implementation at:
+// https://github.com/kyutai-labs/moshi
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+pub use candle;
+pub use candle_nn;
+
+pub mod conv;
+pub mod encodec;
+pub mod quantization;
+pub mod seanet;
+pub mod transformer;
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum NormType {
+ RmsNorm,
+ LayerNorm,
+}
+
+pub use encodec::{load, Config, Encodec as Model};
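+
+// A minimal, non-streaming usage sketch (assuming mimi weights stored in a
+// safetensors file and f32 pcm data sampled at 24kHz):
+//
+// let device = candle::Device::Cpu;
+// let mut model = load("model.safetensors", None, &device)?;
+// let codes = model.encode(&pcm)?; // pcm: [batch, channels, time]
+// let pcm_out = model.decode(&codes)?; // codes: [batch, n_q, frames]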
diff --git a/candle-transformers/src/models/mimi/quantization.rs b/candle-transformers/src/models/mimi/quantization.rs
new file mode 100644
index 00000000..3fde1647
--- /dev/null
+++ b/candle-transformers/src/models/mimi/quantization.rs
@@ -0,0 +1,404 @@
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
+use candle_nn::{linear, Linear, VarBuilder};
+
+struct CodebookEncode;
+
+impl candle::CustomOp2 for CodebookEncode {
+ fn name(&self) -> &'static str {
+ "cb"
+ }
+
+ fn cpu_fwd(
+ &self,
+ lhs_storage: &candle::CpuStorage,
+ lhs_layout: &Layout,
+ rhs_storage: &candle::CpuStorage,
+ rhs_layout: &Layout,
+ ) -> Result<(candle::CpuStorage, Shape)> {
+ use rayon::prelude::*;
+
+ let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
+ let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
+ if lhs_dim2 != rhs_dim2 {
+ candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
+ }
+ if lhs_dim2 == 0 {
+ candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
+ }
+ let lhs = match lhs_layout.contiguous_offsets() {
+ None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
+ Some((o1, o2)) => {
+ let slice = lhs_storage.as_slice::<f32>()?;
+ &slice[o1..o2]
+ }
+ };
+ let rhs = match rhs_layout.contiguous_offsets() {
+ None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
+ Some((o1, o2)) => {
+ let slice = rhs_storage.as_slice::<f32>()?;
+ &slice[o1..o2]
+ }
+ };
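+        // For each lhs row, find the index of the nearest rhs row (codebook
+        // entry) in squared euclidean distance, parallelizing over the lhs rows.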
+ let dst = (0..lhs_dim1)
+ .into_par_iter()
+ .map(|idx1| {
+ let mut where_min = 0;
+ let mut min_dist = f32::INFINITY;
+ let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
+ for idx2 in 0..rhs_dim1 {
+ let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
+ let mut dist = 0f32;
+ for (a, b) in lhs.iter().zip(rhs.iter()) {
+ dist += (a - b) * (a - b)
+ }
+ if dist < min_dist {
+ min_dist = dist;
+ where_min = idx2;
+ }
+ }
+ where_min as u32
+ })
+ .collect();
+ let storage = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((storage, (lhs_dim1,).into()))
+ }
+}
+
+#[allow(unused)]
+#[derive(Debug, Clone)]
+pub struct EuclideanCodebook {
+ initialized: Tensor,
+ cluster_usage: Tensor,
+ embedding_sum: Tensor,
+ embedding: Tensor,
+ c2: Tensor,
+ epsilon: f64,
+ dim: usize,
+ span_encode: tracing::Span,
+ span_decode: tracing::Span,
+}
+
+impl EuclideanCodebook {
+ pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
+ let epsilon = 1e-5;
+ let initialized = vb.get(1, "initialized")?;
+ let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
+ let embedding_sum = vb.get((codebook_size, dim), "embed_sum")?;
+ let embedding = {
+ let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
+ embedding_sum.broadcast_div(&cluster_usage)?
+ };
+ let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
+ Ok(Self {
+ initialized,
+ cluster_usage,
+ embedding_sum,
+ embedding,
+ c2,
+ epsilon,
+ dim,
+ span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
+            span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-decode"),
+ })
+ }
+
+ pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span_encode.enter();
+ let mut target_shape = xs.dims().to_vec();
+ target_shape.pop();
+ let xs = xs.flatten_to(D::Minus2)?;
+ let _ = xs.dims2()?;
+ // TODO: avoid repeating this.
+ let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
+ let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
+ // Manual cdist implementation.
+ let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
+ let dists = diff.sqr()?.sum(D::Minus1)?;
+ let codes = dists.argmin(D::Minus1)?;
+ codes.reshape(target_shape)
+ }
+
+ pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span_encode.enter();
+ let mut target_shape = xs.dims().to_vec();
+ target_shape.pop();
+ let xs = xs.flatten_to(D::Minus2)?;
+ let _ = xs.dims2()?;
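+        // As ||x||^2 is the same for every codebook entry, argmin_e ||x - e||^2
+        // reduces to argmin_e (||e||^2 / 2 - x.e); c2 caches the ||e||^2 / 2 term.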
+ let dot_prod = xs.matmul(&self.embedding.t()?)?;
+ let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
+ codes.reshape(target_shape)
+ }
+
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span_encode.enter();
+ let mut target_shape = xs.dims().to_vec();
+ target_shape.pop();
+ let xs = xs.flatten_to(D::Minus2)?;
+ let _ = xs.dims2()?;
+ let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
+ codes.reshape(target_shape)
+ }
+
+ pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
+ let _enter = self.span_decode.enter();
+ // let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?;
+ let mut final_dims = indexes.dims().to_vec();
+ final_dims.push(self.dim);
+ let indexes = indexes.flatten_all()?;
+ let values = self.embedding.index_select(&indexes, 0)?;
+ let values = values.reshape(final_dims)?;
+ Ok(values)
+ }
+}
+
+#[allow(unused)]
+#[derive(Debug, Clone)]
+pub struct VectorQuantization {
+ project_in: Option<Linear>,
+ project_out: Option<Linear>,
+ codebook: EuclideanCodebook,
+}
+
+impl VectorQuantization {
+ pub fn new(
+ dim: usize,
+ codebook_size: usize,
+ codebook_dim: Option<usize>,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let codebook_dim = codebook_dim.unwrap_or(dim);
+ let (project_in, project_out) = if codebook_dim == dim {
+ (None, None)
+ } else {
+ let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
+ let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
+ (Some(p_in), Some(p_out))
+ };
+ let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("codebook"))?;
+ Ok(Self {
+ project_in,
+ project_out,
+ codebook,
+ })
+ }
+
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.t()?.apply(&self.project_in.as_ref())?;
+ self.codebook.encode_slow(&xs)
+ }
+
+ pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+ let quantized = self.codebook.decode(codes)?;
+ let quantized = match &self.project_out {
+ None => quantized,
+ Some(p) => quantized.apply(p)?,
+ };
+ quantized.t()
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct ResidualVectorQuantization {
+ layers: Vec<VectorQuantization>,
+}
+
+impl ResidualVectorQuantization {
+ pub fn new(
+ n_q: usize,
+ dim: usize,
+ codebook_size: usize,
+ codebook_dim: Option<usize>,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb = vb.pp("layers");
+ let mut layers = Vec::with_capacity(n_q);
+ for i in 0..n_q {
+ let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
+ layers.push(layer)
+ }
+ Ok(Self { layers })
+ }
+
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut codes = Vec::with_capacity(self.layers.len());
+ let mut residual = xs.clone();
+ for layer in self.layers.iter() {
+ let indices = layer.encode(&residual)?;
+ let quantized = layer.decode(&indices)?;
+ residual = (residual - quantized)?;
+ codes.push(indices)
+ }
+ Tensor::stack(&codes, 0)
+ }
+
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ if self.layers.is_empty() {
+ candle::bail!("empty layers in ResidualVectorQuantization")
+ }
+ if self.layers.len() != xs.dim(0)? {
+ candle::bail!(
+ "mismatch between the number of layers {} and the code shape {:?}",
+ self.layers.len(),
+ xs.shape()
+ )
+ }
+ let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
+ for (i, layer) in self.layers.iter().enumerate().skip(1) {
+ let xs = xs.i(i)?;
+ quantized = (quantized + layer.decode(&xs))?
+ }
+ Ok(quantized)
+ }
+}
+
+#[allow(unused)]
+#[derive(Debug, Clone)]
+pub struct ResidualVectorQuantizer {
+ vq: ResidualVectorQuantization,
+ input_proj: Option<candle_nn::Conv1d>,
+ output_proj: Option<candle_nn::Conv1d>,
+}
+
+impl ResidualVectorQuantizer {
+ pub fn new(
+ dim: usize,
+ input_dim: Option<usize>,
+ output_dim: Option<usize>,
+ n_q: usize,
+ bins: usize,
+ force_projection: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let input_dim = input_dim.unwrap_or(dim);
+ let output_dim = output_dim.unwrap_or(dim);
+
+ let input_proj = if input_dim == dim && !force_projection {
+ None
+ } else {
+ let c = candle_nn::conv1d_no_bias(
+ input_dim,
+ dim,
+ 1,
+ Default::default(),
+ vb.pp("input_proj"),
+ )?;
+ Some(c)
+ };
+ let output_proj = if output_dim == dim && !force_projection {
+ None
+ } else {
+ let c = candle_nn::conv1d_no_bias(
+ dim,
+ output_dim,
+ 1,
+ Default::default(),
+ vb.pp("output_proj"),
+ )?;
+ Some(c)
+ };
+
+ let vq = ResidualVectorQuantization::new(
+ n_q, dim, /* codebook_size */ bins, /* codebook_dim */ None, vb,
+ )?;
+ Ok(Self {
+ vq,
+ input_proj,
+ output_proj,
+ })
+ }
+
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
+ codes.transpose(0, 1)
+ }
+
+ pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+        // codes is [B, K, T], with T the number of frames and K the number of
+        // codebooks; vq.decode expects [K, B, T].
+ let codes = codes.transpose(0, 1)?;
+ let quantized = self.vq.decode(&codes)?;
+ match &self.output_proj {
+ None => Ok(quantized),
+ Some(p) => quantized.apply(p),
+ }
+ }
+}
+
+// We do not use any codebook_offset at the moment. When reconstructing the codes, we
+// could just concatenate the indexes.
+#[derive(Debug, Clone)]
+pub struct SplitResidualVectorQuantizer {
+ rvq_first: ResidualVectorQuantizer,
+ rvq_rest: ResidualVectorQuantizer,
+ n_q: usize,
+ span_encode: tracing::Span,
+ span_decode: tracing::Span,
+}
+
+impl SplitResidualVectorQuantizer {
+ pub fn new(
+ dim: usize,
+ input_dim: Option<usize>,
+ output_dim: Option<usize>,
+ n_q: usize,
+ bins: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let rvq_first = ResidualVectorQuantizer::new(
+ dim,
+ input_dim,
+ output_dim,
+ 1,
+ bins,
+ true,
+ vb.pp("semantic_residual_vector_quantizer"),
+ )?;
+ let rvq_rest = ResidualVectorQuantizer::new(
+ dim,
+ input_dim,
+ output_dim,
+ n_q - 1,
+ bins,
+ true,
+ vb.pp("acoustic_residual_vector_quantizer"),
+ )?;
+ let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
+ let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
+ Ok(Self {
+ rvq_first,
+ rvq_rest,
+ n_q,
+ span_encode,
+ span_decode,
+ })
+ }
+
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span_encode.enter();
+ let codes = self.rvq_first.encode(xs)?;
+ if self.n_q > 1 {
+            // We encode xs again here rather than the residual. The decomposition is
+            // not hierarchical: rvq_first produces the semantic tokens while rvq_rest
+            // produces the acoustic tokens.
+ let rest_codes = self.rvq_rest.encode(xs)?;
+ Tensor::cat(&[codes, rest_codes], 1)
+ } else {
+ Ok(codes)
+ }
+ }
+
+ pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+        // codes is [B, K, T], with T the number of frames and K the number of codebooks.
+ let _enter = self.span_decode.enter();
+ let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
+ let quantized = if self.n_q > 1 {
+ (quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
+ } else {
+ quantized
+ };
+ Ok(quantized)
+ }
+}
diff --git a/candle-transformers/src/models/mimi/seanet.rs b/candle-transformers/src/models/mimi/seanet.rs
new file mode 100644
index 00000000..aa5c7d21
--- /dev/null
+++ b/candle-transformers/src/models/mimi/seanet.rs
@@ -0,0 +1,465 @@
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor};
+use candle_nn::VarBuilder;
+
+use super::conv::{StreamableConv1d, StreamableConvTranspose1d};
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub dimension: usize,
+ pub channels: usize,
+ pub causal: bool,
+ pub n_filters: usize,
+ pub n_residual_layers: usize,
+ pub ratios: Vec<usize>,
+ pub activation: candle_nn::Activation,
+ pub norm: super::conv::Norm,
+ pub kernel_size: usize,
+ pub residual_kernel_size: usize,
+ pub last_kernel_size: usize,
+ pub dilation_base: usize,
+ pub pad_mode: super::conv::PadMode,
+ pub true_skip: bool,
+ pub compress: usize,
+ pub lstm: usize,
+ pub disable_norm_outer_blocks: usize,
+ pub final_activation: Option<candle_nn::Activation>,
+}
+
+#[derive(Debug, Clone)]
+pub struct SeaNetResnetBlock {
+ block: Vec<StreamableConv1d>,
+ shortcut: Option<StreamableConv1d>,
+ activation: candle_nn::Activation,
+ skip_op: candle::StreamingBinOp,
+ span: tracing::Span,
+}
+
+impl SeaNetResnetBlock {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ dim: usize,
+ k_sizes_and_dilations: &[(usize, usize)],
+ activation: candle_nn::Activation,
+ norm: Option<super::conv::Norm>,
+ causal: bool,
+ pad_mode: super::conv::PadMode,
+ compress: usize,
+ true_skip: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
+ let hidden = dim / compress;
+ let vb_b = vb.pp("block");
+ for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
+ let in_c = if i == 0 { dim } else { hidden };
+ let out_c = if i == k_sizes_and_dilations.len() - 1 {
+ dim
+ } else {
+ hidden
+ };
+ let c = StreamableConv1d::new(
+ in_c,
+ out_c,
+ /* k_size */ *k_size,
+ /* stride */ 1,
+ /* dilation */ *dilation,
+ /* groups */ 1,
+ /* bias */ true,
+ /* causal */ causal,
+ /* norm */ norm,
+ /* pad_mode */ pad_mode,
+ vb_b.pp(2 * i + 1),
+ )?;
+ block.push(c)
+ }
+ let shortcut = if true_skip {
+ None
+ } else {
+ let c = StreamableConv1d::new(
+ dim,
+ dim,
+ /* k_size */ 1,
+ /* stride */ 1,
+ /* dilation */ 1,
+ /* groups */ 1,
+ /* bias */ true,
+ /* causal */ causal,
+ /* norm */ norm,
+ /* pad_mode */ pad_mode,
+ vb.pp("shortcut"),
+ )?;
+ Some(c)
+ };
+ Ok(Self {
+ block,
+ shortcut,
+ activation,
+ skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
+ span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
+ })
+ }
+}
+
+impl Module for SeaNetResnetBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut ys = xs.clone();
+ for block in self.block.iter() {
+ ys = ys.apply(&self.activation)?.apply(block)?;
+ }
+ match self.shortcut.as_ref() {
+ None => ys + xs,
+ Some(shortcut) => ys + xs.apply(shortcut),
+ }
+ }
+}
+
+impl StreamingModule for SeaNetResnetBlock {
+ fn reset_state(&mut self) {
+ for block in self.block.iter_mut() {
+ block.reset_state()
+ }
+ if let Some(shortcut) = self.shortcut.as_mut() {
+ shortcut.reset_state()
+ }
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let _enter = self.span.enter();
+ let mut ys = xs.clone();
+ for block in self.block.iter_mut() {
+ ys = block.step(&ys.apply(&self.activation)?)?;
+ }
+ match self.shortcut.as_ref() {
+ None => self.skip_op.step(&ys, xs),
+ Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct EncoderLayer {
+ residuals: Vec<SeaNetResnetBlock>,
+ downsample: StreamableConv1d,
+}
+
+#[derive(Debug, Clone)]
+pub struct SeaNetEncoder {
+ init_conv1d: StreamableConv1d,
+ activation: candle_nn::Activation,
+ layers: Vec<EncoderLayer>,
+ final_conv1d: StreamableConv1d,
+ span: tracing::Span,
+}
+
+impl SeaNetEncoder {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ if cfg.lstm > 0 {
+ candle::bail!("seanet lstm is not supported")
+ }
+ let n_blocks = 2 + cfg.ratios.len();
+ let mut mult = 1usize;
+ let init_norm = if cfg.disable_norm_outer_blocks >= 1 {
+ None
+ } else {
+ Some(cfg.norm)
+ };
+ let mut layer_idx = 0;
+ let vb = vb.pp("layers");
+ let init_conv1d = StreamableConv1d::new(
+ cfg.channels,
+ mult * cfg.n_filters,
+ cfg.kernel_size,
+ /* stride */ 1,
+ /* dilation */ 1,
+ /* groups */ 1,
+ /* bias */ true,
+ /* causal */ cfg.causal,
+ /* norm */ init_norm,
+ /* pad_mode */ cfg.pad_mode,
+ vb.pp(layer_idx),
+ )?;
+ layer_idx += 1;
+ let mut layers = Vec::with_capacity(cfg.ratios.len());
+
+ for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
+ let norm = if cfg.disable_norm_outer_blocks >= i + 2 {
+ None
+ } else {
+ Some(cfg.norm)
+ };
+ let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
+ for j in 0..cfg.n_residual_layers {
+ let resnet_block = SeaNetResnetBlock::new(
+ mult * cfg.n_filters,
+ &[
+ (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
+ (1, 1),
+ ],
+ cfg.activation,
+ norm,
+ cfg.causal,
+ cfg.pad_mode,
+ cfg.compress,
+ cfg.true_skip,
+ vb.pp(layer_idx),
+ )?;
+ residuals.push(resnet_block);
+ layer_idx += 1;
+ }
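+ // `layer_idx + 1` skips over the slot taken by the activation in the
+ // original layer list (it carries no weights), and `layer_idx += 2`
+ // accounts for both the activation and the downsampling conv.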
+ let downsample = StreamableConv1d::new(
+ mult * cfg.n_filters,
+ mult * cfg.n_filters * 2,
+ /* k_size */ ratio * 2,
+ /* stride */ ratio,
+ /* dilation */ 1,
+ /* groups */ 1,
+ /* bias */ true,
+ /* causal */ true,
+ /* norm */ norm,
+ /* pad_mode */ cfg.pad_mode,
+ vb.pp(layer_idx + 1),
+ )?;
+ layer_idx += 2;
+ let layer = EncoderLayer {
+ downsample,
+ residuals,
+ };
+ layers.push(layer);
+ mult *= 2
+ }
+
+ let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks {
+ None
+ } else {
+ Some(cfg.norm)
+ };
+ let final_conv1d = StreamableConv1d::new(
+ mult * cfg.n_filters,
+ cfg.dimension,
+ cfg.last_kernel_size,
+ /* stride */ 1,
+ /* dilation */ 1,
+ /* groups */ 1,
+ /* bias */ true,
+ /* causal */ cfg.causal,
+ /* norm */ final_norm,
+ /* pad_mode */ cfg.pad_mode,
+ vb.pp(layer_idx + 1),
+ )?;
+ Ok(Self {
+ init_conv1d,
+ activation: cfg.activation,
+ layers,
+ final_conv1d,
+ span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
+ })
+ }
+}
+
+impl Module for SeaNetEncoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.apply(&self.init_conv1d)?;
+ for layer in self.layers.iter() {
+ for residual in layer.residuals.iter() {
+ xs = xs.apply(residual)?
+ }
+ xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
+ }
+ xs.apply(&self.activation)?.apply(&self.final_conv1d)
+ }
+}
+
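+// Streaming variant: each convolution keeps its own left-context state so
+// that audio can be processed chunk by chunk.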
+impl StreamingModule for SeaNetEncoder {
+ fn reset_state(&mut self) {
+ self.init_conv1d.reset_state();
+ self.layers.iter_mut().for_each(|v| {
+ v.residuals.iter_mut().for_each(|v| v.reset_state());
+ v.downsample.reset_state()
+ });
+ self.final_conv1d.reset_state();
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let _enter = self.span.enter();
+ let mut xs = self.init_conv1d.step(xs)?;
+ for layer in self.layers.iter_mut() {
+ for residual in layer.residuals.iter_mut() {
+ xs = residual.step(&xs)?;
+ }
+ xs = layer.downsample.step(&xs.apply(&self.activation)?)?;
+ }
+ self.final_conv1d.step(&xs.apply(&self.activation)?)
+ }
+}
+
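+// A single decoder stage, mirroring the encoder: a transposed convolution
+// upsamples the sequence before a stack of residual blocks refines it.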
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+ upsample: StreamableConvTranspose1d,
+ residuals: Vec<SeaNetResnetBlock>,
+}
+
+#[derive(Debug, Clone)]
+pub struct SeaNetDecoder {
+ init_conv1d: StreamableConv1d,
+ activation: candle_nn::Activation,
+ layers: Vec<DecoderLayer>,
+ final_conv1d: StreamableConv1d,
+ final_activation: Option<candle_nn::Activation>,
+ span: tracing::Span,
+}
+
+impl SeaNetDecoder {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ if cfg.lstm > 0 {
+ candle::bail!("seanet lstm is not supported")
+ }
+ let n_blocks = 2 + cfg.ratios.len();
+ let mut mult = 1 << cfg.ratios.len();
+ let init_norm = if cfg.disable_norm_outer_blocks == n_blocks {
+ None
+ } else {
+ Some(cfg.norm)
+ };
+ let mut layer_idx = 0;
+ let vb = vb.pp("layers");
+ let init_conv1d = StreamableConv1d::new(
+ cfg.dimension,
+ mult * cfg.n_filters,
+ cfg.kernel_size,
+ /* stride */ 1,
+ /* dilation */ 1,
+ /* groups */ 1,
+ /* bias */ true,
+ /* causal */ cfg.causal,
+ /* norm */ init_norm,
+ /* pad_mode */ cfg.pad_mode,
+ vb.pp(layer_idx),
+ )?;
+ layer_idx += 1;
+ let mut layers = Vec::with_capacity(cfg.ratios.len());
+ for (i, &ratio) in cfg.ratios.iter().enumerate() {
+ let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
+ None
+ } else {
+ Some(cfg.norm)
+ };
+ let upsample = StreamableConvTranspose1d::new(
+ mult * cfg.n_filters,
+ mult * cfg.n_filters / 2,
+ /* k_size */ ratio * 2,
+ /* stride */ ratio,
+ /* groups */ 1,
+ /* bias */ true,
+ /* causal */ true,
+ /* norm */ norm,
+ vb.pp(layer_idx + 1),
+ )?;
+ layer_idx += 2;
+
+ let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
+ for j in 0..cfg.n_residual_layers {
+ let resnet_block = SeaNetResnetBlock::new(
+ mult * cfg.n_filters / 2,
+ &[
+ (cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
+ (1, 1),
+ ],
+ cfg.activation,
+ norm,
+ cfg.causal,
+ cfg.pad_mode,
+ cfg.compress,
+ cfg.true_skip,
+ vb.pp(layer_idx),
+ )?;
+ residuals.push(resnet_block);
+ layer_idx += 1;
+ }
+ let layer = DecoderLayer {
+ upsample,
+ residuals,
+ };
+ layers.push(layer);
+ mult /= 2
+ }
+ let final_norm = if cfg.disable_norm_outer_blocks >= 1 {
+ None
+ } else {
+ Some(cfg.norm)
+ };
+ let final_conv1d = StreamableConv1d::new(
+ cfg.n_filters,
+ cfg.channels,
+ cfg.last_kernel_size,
+ /* stride */ 1,
+ /* dilation */ 1,
+ /* groups */ 1,
+ /* bias */ true,
+ /* causal */ cfg.causal,
+ /* norm */ final_norm,
+ /* pad_mode */ cfg.pad_mode,
+ vb.pp(layer_idx + 1),
+ )?;
+ Ok(Self {
+ init_conv1d,
+ activation: cfg.activation,
+ layers,
+ final_conv1d,
+ final_activation: cfg.final_activation,
+ span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
+ })
+ }
+}
+
+impl Module for SeaNetDecoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.apply(&self.init_conv1d)?;
+ for layer in self.layers.iter() {
+ xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
+ for residual in layer.residuals.iter() {
+ xs = xs.apply(residual)?
+ }
+ }
+ let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
+ let xs = match self.final_activation.as_ref() {
+ None => xs,
+ Some(act) => xs.apply(act)?,
+ };
+ Ok(xs)
+ }
+}
+
+impl StreamingModule for SeaNetDecoder {
+ fn reset_state(&mut self) {
+ self.init_conv1d.reset_state();
+ self.layers.iter_mut().for_each(|v| {
+ v.residuals.iter_mut().for_each(|v| v.reset_state());
+ v.upsample.reset_state()
+ });
+ self.final_conv1d.reset_state();
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let _enter = self.span.enter();
+ let mut xs = self.init_conv1d.step(xs)?;
+ for layer in self.layers.iter_mut() {
+ xs = layer.upsample.step(&xs.apply(&self.activation)?)?;
+ for residual in layer.residuals.iter_mut() {
+ xs = residual.step(&xs)?;
+ }
+ }
+ let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?;
+ let xs = match self.final_activation.as_ref() {
+ None => xs,
+ Some(act) => xs.apply(act)?,
+ };
+ Ok(xs)
+ }
+}
diff --git a/candle-transformers/src/models/mimi/transformer.rs b/candle-transformers/src/models/mimi/transformer.rs
new file mode 100644
index 00000000..de221274
--- /dev/null
+++ b/candle-transformers/src/models/mimi/transformer.rs
@@ -0,0 +1,802 @@
+// Copyright (c) Kyutai, all rights reserved.
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+use candle::{DType, Device, IndexOp, Module, Result, StreamTensor, StreamingModule, Tensor, D};
+use candle_nn::{linear_no_bias, Linear, VarBuilder};
+use std::sync::Arc;
+
+fn linear(in_d: usize, out_d: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+ if bias {
+ candle_nn::linear(in_d, out_d, vb)
+ } else {
+ linear_no_bias(in_d, out_d, vb)
+ }
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum PositionalEmbedding {
+ Rope,
+ Sin,
+ None,
+}
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub d_model: usize,
+ pub num_heads: usize,
+ pub num_layers: usize,
+ pub causal: bool,
+ pub norm_first: bool,
+ pub bias_ff: bool,
+ pub bias_attn: bool,
+ pub layer_scale: Option<f64>,
+ pub positional_embedding: PositionalEmbedding,
+ pub use_conv_block: bool,
+ pub cross_attention: bool,
+ pub conv_kernel_size: usize,
+ pub use_conv_bias: bool,
+ pub gating: Option<candle_nn::Activation>,
+ pub norm: super::NormType,
+ pub context: usize,
+ pub max_period: usize,
+ pub max_seq_len: usize,
+
+ pub kv_repeat: usize,
+ pub dim_feedforward: usize,
+ pub conv_layout: bool,
+}
+
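+// Precomputed sin/cos tables for rotary position embeddings. Frequencies
+// follow the usual schedule, 1 / theta^(2k / dim) for the k-th pair of
+// channels, and `rope_i` applies the interleaved variant.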
+#[derive(Debug, Clone)]
+pub struct RotaryEmbedding {
+ sin: Tensor,
+ cos: Tensor,
+ span: tracing::Span,
+}
+
+impl RotaryEmbedding {
+ pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result<Self> {
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+ let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+ .to_dtype(DType::F32)?
+ .reshape((max_seq_len, 1))?;
+ let freqs = t.matmul(&inv_freq)?;
+ Ok(Self {
+ sin: freqs.sin()?,
+ cos: freqs.cos()?,
+ span: tracing::span!(tracing::Level::TRACE, "rot"),
+ })
+ }
+
+ pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?;
+ let qk_dtype = qk.dtype();
+ let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
+ let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
+ candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct LayerScale {
+ scale: Tensor,
+}
+
+impl LayerScale {
+ pub fn new(d_model: usize, _init: f64, vb: VarBuilder) -> Result<Self> {
+ let scale = vb.get(d_model, "scale")?;
+ Ok(Self { scale })
+ }
+}
+
+impl Module for LayerScale {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.broadcast_mul(&self.scale)
+ }
+}
+
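+// Builds the attention mask for a block of `size1` queries attending to
+// `size2` keys, where the queries are aligned with the *last* positions of
+// the key range. An entry is 1 (masked) when the key is in the future or
+// more than `context` steps in the past. A hand-worked example (not from
+// the original source): size1 = 2, size2 = 4, context = 2 gives
+//   [[0, 0, 0, 1],
+//    [1, 0, 0, 0]]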
+pub(crate) fn get_mask(
+ size1: usize,
+ size2: usize,
+ context: usize,
+ device: &Device,
+) -> Result<Tensor> {
+ let mask: Vec<_> = (0..size1)
+ .flat_map(|i| {
+ (0..size2)
+ .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
+ })
+ .collect();
+ Tensor::from_slice(&mask, (size1, size2), device)
+}
+
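+// Streaming self-attention: keys and values are appended to a KvCache
+// across calls, and everything older than `context` is trimmed away so the
+// effective attention window stays bounded.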
+#[derive(Debug, Clone)]
+pub struct StreamingMultiheadAttention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ out_proj: Linear,
+ kv_repeat: usize,
+ num_heads: usize,
+ context: usize,
+ neg_inf: Tensor,
+ rope: Option<Arc<RotaryEmbedding>>,
+ kv_cache: candle_nn::kv_cache::KvCache,
+ pos: usize,
+ use_flash_attn: bool,
+ span: tracing::Span,
+}
+
+impl StreamingMultiheadAttention {
+ pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let embed_dim = cfg.d_model;
+ let num_kv = cfg.num_heads / cfg.kv_repeat;
+ let kv_dim = num_kv * (embed_dim / cfg.num_heads);
+ let q_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("q_proj"))?;
+ let k_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("k_proj"))?;
+ let v_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("v_proj"))?;
+ let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("o_proj"))?;
+ let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ out_proj,
+ rope: rope.clone(),
+ kv_repeat: cfg.kv_repeat,
+ num_heads: cfg.num_heads,
+ context: cfg.context,
+ neg_inf,
+ kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
+ pos: 0,
+ use_flash_attn: false,
+ span: tracing::span!(tracing::Level::TRACE, "mha"),
+ })
+ }
+
+ pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ if self.kv_repeat != 1 {
+ candle::bail!("only kv-repeat = 1 is supported")
+ }
+ let (b, t, hd) = xs.dims3()?;
+ let head_dim = hd / self.num_heads;
+ let q = xs
+ .apply(&self.q_proj)?
+ .reshape((b, t, self.num_heads, head_dim))?;
+ let k = xs
+ .apply(&self.k_proj)?
+ .reshape((b, t, self.num_heads, head_dim))?;
+ let v = xs
+ .apply(&self.v_proj)?
+ .reshape((b, t, self.num_heads, head_dim))?;
+ // qk_layer_norm = None
+ // kv_repeat = 1, otherwise we would need repeat_kv
+ let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
+ let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
+ let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
+ if let Some(rope) = &self.rope {
+ q = rope.apply_rotary_emb(&q, self.pos)?;
+ k = rope.apply_rotary_emb(&k, self.pos)?;
+ }
+
+ let (k, v) = {
+ self.pos += k.dim(2)?;
+ self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?
+ };
+ // The KV cache currently keeps all past entries; trim the part that comes
+ // from the cache down to at most `context` entries so that it stays
+ // consistent with the shape of the mask we provide.
+ let k_len = k.dim(2)?;
+ let k_target_len = t + usize::min(self.context, k_len - t);
+ let (k, v) = if k_target_len < k_len {
+ let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
+ let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
+ (k, v)
+ } else {
+ (k.clone(), v.clone())
+ };
+
+ let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
+ let q = q.transpose(1, 2)?;
+ let k = k.transpose(1, 2)?;
+ let v = v.transpose(1, 2)?;
+ let softmax_scale = 1f32 / (head_dim as f32).sqrt();
+ flash_attn(&q, &k, &v, softmax_scale, t > 1)?.transpose(1, 2)?
+ } else {
+ let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
+ let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
+
+ let pre_ws = match mask {
+ None => pre_ws,
+ Some(mask) => {
+ let mask = mask.broadcast_left((b, self.num_heads))?;
+ let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
+ mask.where_cond(&neg_inf, &pre_ws)?
+ }
+ };
+
+ let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
+ ws.matmul(&v)? // b,h,t,d
+ };
+ let xs = xs
+ .transpose(1, 2)? // b,t,h,d
+ .reshape((b, t, hd))?
+ .apply(&self.out_proj)?;
+ Ok(xs)
+ }
+
+ pub fn reset_kv_cache(&mut self) {
+ self.kv_cache.reset()
+ }
+
+ pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+ self.kv_cache = kv_cache
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct StreamingMultiheadCrossAttention {
+ in_proj_q: Linear,
+ in_proj_k: Linear,
+ in_proj_v: Linear,
+ out_proj: Linear,
+ kv_repeat: usize,
+ num_heads: usize,
+ neg_inf: Tensor,
+ span: tracing::Span,
+}
+
+impl StreamingMultiheadCrossAttention {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let embed_dim = cfg.d_model;
+ let num_kv = cfg.num_heads / cfg.kv_repeat;
+ let kv_dim = num_kv * (embed_dim / cfg.num_heads);
+ let out_dim = embed_dim + 2 * kv_dim;
+ let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
+ let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
+ let in_proj_weight_k = in_proj_weight.narrow(0, embed_dim, kv_dim)?;
+ let in_proj_weight_v = in_proj_weight.narrow(0, embed_dim + kv_dim, kv_dim)?;
+ let (in_proj_bias_q, in_proj_bias_k, in_proj_bias_v) = if cfg.bias_attn {
+ let b = vb.get(out_dim, "in_proj_bias")?;
+ let q = b.narrow(0, 0, embed_dim)?;
+ let k = b.narrow(0, embed_dim, kv_dim)?;
+ let v = b.narrow(0, embed_dim + kv_dim, kv_dim)?;
+ (Some(q), Some(k), Some(v))
+ } else {
+ (None, None, None)
+ };
+ let in_proj_q = Linear::new(in_proj_weight_q, in_proj_bias_q);
+ let in_proj_k = Linear::new(in_proj_weight_k, in_proj_bias_k);
+ let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
+ let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
+ let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
+ Ok(Self {
+ in_proj_q,
+ in_proj_k,
+ in_proj_v,
+ out_proj,
+ kv_repeat: cfg.kv_repeat,
+ num_heads: cfg.num_heads,
+ neg_inf,
+ span: tracing::span!(tracing::Level::TRACE, "mhca"),
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, ca_src: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ if self.kv_repeat != 1 {
+ candle::bail!("only kv-repeat = 1 is supported")
+ }
+ let (b, t, hd) = xs.dims3()?;
+ let head_dim = hd / self.num_heads;
+ // time_dim = 1, layout: b,t,h,d
+ let q = xs.apply(&self.in_proj_q)?;
+ let k = ca_src.apply(&self.in_proj_k)?;
+ let v = ca_src.apply(&self.in_proj_v)?;
+ let (ca_b, ca_t, ca_dim) = k.dims3()?;
+ let q = q.reshape((b, t, self.num_heads, head_dim))?;
+ let k = k.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
+ let v = v.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
+ // qk_layer_norm = None
+ // kv_repeat = 1, otherwise we would need repeat_kv
+ let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
+ let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
+ let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
+
+ let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
+ let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
+
+ let pre_ws = match mask {
+ None => pre_ws,
+ Some(mask) => {
+ let mask = mask.broadcast_left((b, self.num_heads))?;
+ let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
+ mask.where_cond(&neg_inf, &pre_ws)?
+ }
+ };
+
+ let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
+ let xs = ws.matmul(&v)?; // b,h,t,d
+ let xs = xs
+ .transpose(1, 2)? // b,t,h,d
+ .reshape((b, t, hd))?
+ .apply(&self.out_proj)?;
+ Ok(xs)
+ }
+}
+
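+// Two feed-forward variants: a plain two-layer MLP with a GELU in between,
+// and a gated unit where `linear_in` produces both the gate and the value
+// halves in a single projection.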
+#[derive(Debug, Clone)]
+pub enum Mlp {
+ NoGating {
+ span1: tracing::Span,
+ linear1: Linear,
+ span2: tracing::Span,
+ linear2: Linear,
+ span: tracing::Span,
+ },
+ Gating {
+ linear_in: Linear,
+ linear_out: Linear,
+ activation: candle_nn::Activation,
+ span: tracing::Span,
+ },
+}
+
+impl Mlp {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let d_model = cfg.d_model;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp");
+
+ match cfg.gating {
+ None => {
+ let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
+ let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
+ let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("mlp.fc1"))?;
+ let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("mlp.fc2"))?;
+ Ok(Self::NoGating {
+ linear1,
+ linear2,
+ span,
+ span1,
+ span2,
+ })
+ }
+ Some(activation) => {
+ let vb = vb.pp("gating");
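+ // Shrink the hidden size so the gated MLP (three weight matrices) keeps
+ // roughly the same parameter count as the ungated one (two matrices): the
+ // 2/3 rule gives 2 * dim_feedforward / 3, and the 4 * d_model case
+ // appears to be rounded up to 11 * d_model / 4 for a friendlier multiple.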
+ let hidden = if cfg.dim_feedforward == 4 * d_model {
+ 11 * d_model / 4
+ } else {
+ 2 * cfg.dim_feedforward / 3
+ };
+ // TODO: Maybe use bias_ff here?
+ let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
+ let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
+ Ok(Self::Gating {
+ linear_in,
+ linear_out,
+ activation,
+ span,
+ })
+ }
+ }
+ }
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::NoGating {
+ linear1,
+ linear2,
+ span,
+ span1,
+ span2,
+ } => {
+ let _enter = span.enter();
+ let xs = {
+ let _enter = span1.enter();
+ xs.apply(linear1)?
+ };
+ let xs = xs.gelu_erf()?;
+ {
+ let _enter = span2.enter();
+ xs.apply(linear2)
+ }
+ }
+ Self::Gating {
+ linear_in,
+ linear_out,
+ activation,
+ span,
+ } => {
+ let _enter = span.enter();
+ let xs = xs.apply(linear_in)?;
+ let (b, t, _) = xs.dims3()?;
+ let xs = xs.reshape((b, t, 2, ()))?;
+ let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
+ xs.apply(linear_out)
+ }
+ }
+ }
+}
+
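+// RMS normalization; the checkpoint stores `alpha` with shape (1, 1, d),
+// reshaped here to (d,) as expected by `rms_norm`.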
+#[derive(Debug, Clone)]
+pub struct RmsNorm {
+ pub(crate) alpha: Tensor,
+ pub(crate) eps: f32,
+}
+
+impl RmsNorm {
+ pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
+ let alpha = vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?;
+ Ok(Self { alpha, eps })
+ }
+}
+
+impl Module for RmsNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum Norm {
+ LayerNorm(candle_nn::LayerNorm),
+ RmsNorm(RmsNorm),
+}
+
+impl Norm {
+ pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let norm = match cfg.norm {
+ super::NormType::LayerNorm => {
+ let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
+ Self::LayerNorm(norm)
+ }
+ super::NormType::RmsNorm => {
+ let norm = RmsNorm::new(d_model, 1e-8, vb)?;
+ Self::RmsNorm(norm)
+ }
+ };
+ Ok(norm)
+ }
+}
+
+impl Module for Norm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::LayerNorm(m) => m.forward(xs),
+ Self::RmsNorm(m) => m.forward(xs),
+ }
+ }
+}
+
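+// A pre-norm transformer layer: self-attention and MLP sub-blocks, each
+// with an optional learned per-channel scale, plus an optional
+// cross-attention sub-block in between.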
+#[derive(Debug, Clone)]
+pub struct StreamingTransformerLayer {
+ self_attn: StreamingMultiheadAttention,
+ mlp: Mlp,
+ norm1: Norm,
+ norm2: Norm,
+ layer_scale_1: Option<LayerScale>,
+ layer_scale_2: Option<LayerScale>,
+ cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
+ norm_first: bool,
+ span: tracing::Span,
+}
+
+impl StreamingTransformerLayer {
+ pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ if cfg.use_conv_block {
+ candle::bail!("conv-block is not supported")
+ }
+ let d_model = cfg.d_model;
+ let mlp = Mlp::new(cfg, vb.clone())?;
+ let (norm1, norm2) = match cfg.norm {
+ super::NormType::LayerNorm => {
+ let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("input_layernorm"))?;
+ let norm2 =
+ candle_nn::layer_norm(d_model, 1e-5, vb.pp("post_attention_layernorm"))?;
+ (Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
+ }
+ super::NormType::RmsNorm => {
+ let norm1 = RmsNorm::new(d_model, 1e-8, vb.pp("input_rmsnorm"))?;
+ let norm2 = RmsNorm::new(d_model, 1e-8, vb.pp("post_attention_rmsnorm"))?;
+ (Norm::RmsNorm(norm1), Norm::RmsNorm(norm2))
+ }
+ };
+ let layer_scale_1 = match cfg.layer_scale {
+ None => None,
+ Some(ls) => {
+ let ls = LayerScale::new(d_model, ls, vb.pp("self_attn_layer_scale"))?;
+ Some(ls)
+ }
+ };
+ let layer_scale_2 = match cfg.layer_scale {
+ None => None,
+ Some(ls) => {
+ let ls = LayerScale::new(d_model, ls, vb.pp("mlp_layer_scale"))?;
+ Some(ls)
+ }
+ };
+ let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
+ let cross_attn = if cfg.cross_attention {
+ let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
+ let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
+ Some((norm_cross, cross_attn))
+ } else {
+ None
+ };
+ Ok(Self {
+ self_attn,
+ mlp,
+ norm1,
+ norm2,
+ layer_scale_1,
+ layer_scale_2,
+ cross_attn,
+ norm_first: cfg.norm_first,
+ span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
+ })
+ }
+
+ pub fn forward(
+ &mut self,
+ xs: &Tensor,
+ ca_src: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ if !self.norm_first {
+ candle::bail!("only norm_first = true is supported")
+ }
+ let norm1 = xs.apply(&self.norm1)?;
+ let xs = (xs
+ + self
+ .self_attn
+ .forward(&norm1, mask)?
+ .apply(&self.layer_scale_1.as_ref())?)?;
+
+ let xs = match (&self.cross_attn, ca_src) {
+ (Some((norm_cross, cross_attn)), Some(ca_src)) => {
+ let residual = &xs;
+ let xs = xs.apply(norm_cross)?;
+ (residual + cross_attn.forward(&xs, ca_src, None)?)?
+ }
+ _ => xs,
+ };
+
+ let xs = (&xs
+ + xs.apply(&self.norm2)?
+ .apply(&self.mlp)?
+ .apply(&self.layer_scale_2.as_ref()))?;
+ Ok(xs)
+ }
+
+ pub fn reset_kv_cache(&mut self) {
+ self.self_attn.reset_kv_cache()
+ }
+
+ pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+ self.self_attn.set_kv_cache(kv_cache)
+ }
+}
+
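+// The transformer body. Positional information comes either from RoPE
+// applied inside the attention layers or, for `Sin`, from sinusoidal
+// embeddings added to the input; `None` skips positions entirely.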
+#[derive(Debug, Clone)]
+pub struct StreamingTransformer {
+ layers: Vec<StreamingTransformerLayer>,
+ context: usize,
+ positional_embedding: PositionalEmbedding,
+ max_period: usize,
+}
+
+impl StreamingTransformer {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_l = vb.pp("layers");
+ let rope = match cfg.positional_embedding {
+ PositionalEmbedding::Rope => {
+ let rope = RotaryEmbedding::new(
+ cfg.d_model / cfg.num_heads,
+ cfg.max_seq_len,
+ cfg.max_period as f32,
+ vb.device(),
+ )?;
+ Some(Arc::new(rope))
+ }
+ PositionalEmbedding::Sin | PositionalEmbedding::None => None,
+ };
+ let mut layers = Vec::with_capacity(cfg.num_layers);
+ for layer_idx in 0..cfg.num_layers {
+ let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ Ok(Self {
+ layers,
+ context: cfg.context,
+ positional_embedding: cfg.positional_embedding,
+ max_period: cfg.max_period,
+ })
+ }
+
+ pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+ self.forward_ca(xs, None)
+ }
+
+ pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
+ let (_b, t, c) = xs.dims3()?;
+ // At most `context` entries are used from the kv-cache; the mask built
+ // below also discards any values that fall before that window.
+ let pos = self.layers[0]
+ .self_attn
+ .kv_cache
+ .k_cache()
+ .current_seq_len()
+ .min(self.context);
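+ // With a single query token there is nothing in the future to hide, and
+ // the context trimming happens inside the attention itself, so no mask is
+ // needed.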
+ let mask = if t == 1 {
+ None
+ } else {
+ Some(get_mask(t, pos + t, self.context, xs.device())?)
+ };
+ let mut xs = match self.positional_embedding {
+ PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
+ PositionalEmbedding::Sin => {
+ let dev = xs.device();
+ let theta = self.max_period as f32;
+ let half_dim = c / 2;
+ let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)?
+ .unsqueeze(1)?
+ .to_dtype(DType::F32)?;
+ let inv_freq: Vec<_> = (0..half_dim)
+ .map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+ let freqs = positions.broadcast_mul(&inv_freq)?;
+ let pos_emb =
+ Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?.to_dtype(xs.dtype())?;
+ xs.broadcast_add(&pos_emb)?
+ }
+ };
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, ca_src, mask.as_ref())?;
+ }
+ Ok(xs)
+ }
+
+ pub fn copy_state(&mut self, from: &Self) -> Result<()> {
+ if self.layers.len() != from.layers.len() {
+ candle::bail!("cannot copy kv-caches as the transformers have different depths")
+ }
+ self.layers
+ .iter_mut()
+ .zip(from.layers.iter())
+ .for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
+ Ok(())
+ }
+}
+
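+// The transformer is already stateful through its per-layer kv-caches, so
+// the streaming step simply runs a regular forward on the pending chunk.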
+impl StreamingModule for StreamingTransformer {
+ fn reset_state(&mut self) {
+ self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ match xs.as_option() {
+ None => Ok(StreamTensor::empty()),
+ Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
+ }
+ }
+}
+
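+// Wraps the transformer with optional input/output projections and, when
+// `conv_layout` is set, transposes between the convolutional (b, c, t)
+// layout and the transformer's (b, t, c) layout on the way in and out.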
+#[derive(Debug, Clone)]
+pub struct ProjectedTransformer {
+ transformer: StreamingTransformer,
+ input_proj: Option<Linear>,
+ output_projs: Vec<Option<Linear>>,
+ conv_layout: bool,
+ span: tracing::Span,
+}
+
+impl ProjectedTransformer {
+ pub fn new(
+ input_dim: usize,
+ output_dims: &[usize],
+ cfg: &Config,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let transformer = StreamingTransformer::new(cfg, vb.clone())?;
+ let input_proj = if input_dim == cfg.d_model {
+ None
+ } else {
+ let l = linear_no_bias(input_dim, cfg.d_model, vb.pp("input_proj"))?;
+ Some(l)
+ };
+ let mut output_projs = Vec::with_capacity(output_dims.len());
+ let vb_o = vb.pp("output_projs");
+ for (i, &output_dim) in output_dims.iter().enumerate() {
+ let output_proj = if output_dim == cfg.d_model {
+ None
+ } else {
+ let l = linear_no_bias(cfg.d_model, output_dim, vb_o.pp(i))?;
+ Some(l)
+ };
+ output_projs.push(output_proj)
+ }
+ Ok(Self {
+ transformer,
+ input_proj,
+ output_projs,
+ conv_layout: cfg.conv_layout,
+ span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
+ })
+ }
+
+ pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
+ let _enter = self.span.enter();
+ let xs = if self.conv_layout {
+ xs.transpose(1, 2)?
+ } else {
+ xs.clone()
+ };
+ let xs = xs.apply(&self.input_proj.as_ref())?;
+ let xs = self.transformer.forward(&xs)?;
+ let mut ys = Vec::with_capacity(self.output_projs.len());
+ for output_proj in self.output_projs.iter() {
+ let ys_ = xs.apply(&output_proj.as_ref())?;
+ let ys_ = if self.conv_layout {
+ ys_.transpose(1, 2)?
+ } else {
+ ys_
+ };
+ ys.push(ys_)
+ }
+ Ok(ys)
+ }
+}
+
+impl StreamingModule for ProjectedTransformer {
+ fn reset_state(&mut self) {
+ self.transformer.reset_state()
+ }
+
+ fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
+ let xs = xs.apply(&|x: &Tensor| {
+ if self.conv_layout {
+ x.transpose(1, 2)
+ } else {
+ Ok(x.clone())
+ }
+ })?;
+ let xs = xs.apply(&self.input_proj.as_ref())?;
+ let xs = self.transformer.step(&xs)?;
+ let ys = xs.apply(&self.output_projs[0].as_ref())?;
+ ys.apply(&|y: &Tensor| {
+ if self.conv_layout {
+ y.transpose(1, 2)
+ } else {
+ Ok(y.clone())
+ }
+ })
+ }
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 9f7856ea..07672bcc 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -33,6 +33,7 @@ pub mod llava;
pub mod mamba;
pub mod marian;
pub mod metavoice;
+pub mod mimi;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;