 candle-transformers/src/models/metavoice.rs     | 130 +++++++++++++++++++-----
 candle-transformers/src/models/whisper/audio.rs |   2 +-
 2 files changed, 109 insertions(+), 23 deletions(-)
diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs
index 0ab19041..993f73ef 100644
--- a/candle-transformers/src/models/metavoice.rs
+++ b/candle-transformers/src/models/metavoice.rs
@@ -1,4 +1,4 @@
-use candle::{DType, Error as E, IndexOp, Module, Result, Tensor, D};
+use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
// Equivalent to torch.repeat_interleave
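The repeat_interleave helper itself is elided from this hunk's context. For orientation, here is a minimal candle sketch of a torch.repeat_interleave equivalent; the function name and exact shape handling are assumptions, not taken from this diff:

```rust
use candle::{Result, Tensor};

// Sketch: repeat each element of `xs` `repeats` times along `dim` by inserting
// a size-1 axis, broadcasting it to `repeats`, and merging the two axes.
fn repeat_interleave(xs: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
    let xs = xs.unsqueeze(dim + 1)?;
    let mut dims = xs.dims().to_vec();
    dims[dim + 1] = repeats;
    xs.broadcast_as(dims)?.contiguous()?.flatten(dim, dim + 1)
}
```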
@@ -13,22 +13,41 @@ pub mod speaker_encoder {
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
- pub mel_window_step: usize,
- pub mel_n_channels: usize,
pub sampling_rate: usize,
pub partial_n_frames: usize,
pub model_hidden_size: usize,
pub model_embedding_size: usize,
pub model_num_layers: usize,
+ pub mel_window_length: usize,
+ pub mel_window_step: usize,
+ pub mel_n_channels: usize,
+ }
+
+ impl Config {
+ pub fn cfg() -> Self {
+ Self {
+ sampling_rate: 16_000,
+ partial_n_frames: 160,
+ model_hidden_size: 256,
+ model_embedding_size: 256,
+ model_num_layers: 3,
+ mel_window_length: 25,
+ mel_window_step: 10,
+ mel_n_channels: 40,
+ }
+ }
}
pub struct Model {
lstms: Vec<candle_nn::LSTM>,
linear: Linear,
+ cfg: Config,
}
+ type Slice = (usize, usize);
+
impl Model {
- pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
let mut lstms = Vec::with_capacity(cfg.model_num_layers);
let vb_l = vb.pp("lstm");
for layer_idx in 0..cfg.model_num_layers {
@@ -50,36 +69,103 @@ pub mod speaker_encoder {
true,
vb.pp("linear"),
)?;
- Ok(Self { lstms, linear })
+ Ok(Self { lstms, linear, cfg })
}
fn compute_partial_slices(
- _n_samples: usize,
- _rate: f64,
- _min_coverage: f64,
- ) -> Result<(Tensor, Tensor)> {
- todo!()
- }
-
- pub fn embed_utterance(&self, wav: &[f32], rate: f64, min_coverage: f64) -> Result<Tensor> {
- let (_wav_slices, _mel_slices) =
- Self::compute_partial_slices(wav.len(), rate, min_coverage)?;
- todo!()
+ &self,
+ n_samples: usize,
+ rate: f64,
+ min_coverage: f64,
+ ) -> (Vec<Slice>, Vec<Slice>) {
+ let c = &self.cfg;
+ // Compute how many frames separate two partial utterances
+ let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
+ let n_frames = n_samples / samples_per_frame + 1;
+ let frame_step =
+ (c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
+ let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
+ // Compute the slices.
+ let mut wav_slices = vec![];
+ let mut mel_slices = vec![];
+ for i in (0..steps).step_by(frame_step) {
+ let mel_range = (i, i + c.partial_n_frames);
+ let wav_range = (
+ i * samples_per_frame,
+ (i + c.partial_n_frames) * samples_per_frame,
+ );
+ mel_slices.push(mel_range);
+ wav_slices.push(wav_range);
+ }
+ // Evaluate whether extra padding is warranted or not.
+ let last_wav_range = match wav_slices.last() {
+ None => return (wav_slices, mel_slices),
+ Some(l) => *l,
+ };
+ let coverage = (n_samples - last_wav_range.0) as f64
+ / (last_wav_range.1 - last_wav_range.0) as f64;
+ // Drop the last partial if it covers too little of the actual signal.
+ if coverage < min_coverage && mel_slices.len() > 1 {
+ mel_slices.pop();
+ wav_slices.pop();
+ }
+ (wav_slices, mel_slices)
+ }
+
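To make the slice arithmetic concrete, a worked example with the Config::cfg() defaults above; the rate of 1.3 partials per second is an assumed value matching the usual default of the Python reference implementation, not something this diff specifies:

```rust
fn main() {
    let (sampling_rate, mel_window_step, partial_n_frames) = (16_000usize, 10usize, 160usize);
    let samples_per_frame = sampling_rate * mel_window_step / 1000; // 160 samples = 10 ms
    let n_samples = 2 * sampling_rate; // a 2 second utterance
    let n_frames = n_samples / samples_per_frame + 1; // 201 mel frames
    let rate = 1.3; // partial utterances per second (assumed default)
    let frame_step = (sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
    assert_eq!(frame_step, 77); // partials start every 77 frames and span 160 (~52% overlap)
    let steps = (n_frames + frame_step).saturating_sub(partial_n_frames) + 1;
    // Two partials: frames 0..160 and 77..237, the second extending past the signal,
    // which is exactly the case the coverage check above decides on.
    assert_eq!((0..steps).step_by(frame_step).count(), 2);
}
```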
+ pub fn embed_utterance(
+ &self,
+ wav: &[f32],
+ mel_filters: &[f32],
+ rate: f64,
+ min_c: f64,
+ device: &Device,
+ ) -> Result<Tensor> {
+ let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
+ let max_wave_length = match wav_slices.last() {
+ Some(v) => v.1,
+ None => candle::bail!("empty wav slices"),
+ };
+ let wav = if max_wave_length > wav.len() {
+ let mut wav = wav.to_vec();
+ wav.resize(max_wave_length, 0.0); // pad with zeros up to the end of the last slice
+ std::borrow::Cow::Owned(wav)
+ } else {
+ std::borrow::Cow::Borrowed(wav)
+ };
+ let mel = crate::models::whisper::audio::log_mel_spectrogram_(
+ wav.as_ref(),
+ mel_filters,
+ /* fft_size */ self.cfg.mel_window_length,
+ /* fft_step */ self.cfg.mel_window_step,
+ self.cfg.mel_n_channels,
+ false,
+ );
+ let mels = mel_slices
+ .iter()
+ .flat_map(|s| [mel[s.0], mel[s.1]])
+ .collect::<Vec<_>>();
+ let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
+ let partial_embeds = self.forward(&mels)?;
+ let raw_embed = partial_embeds.mean(0)?;
+ let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
+ raw_embed.broadcast_div(&norm)
}
}
impl Module for Model {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
use candle_nn::RNN;
+
+ // This differs from the PyTorch reference implementation as candle's LSTM is batch-first.
+ let xs = xs.t()?;
let mut xs = xs.clone();
- for lstm in self.lstms.iter() {
- let res = lstm.seq(&xs)?;
- let res: Vec<_> = res.into_iter().map(|s| s.h().clone()).collect();
- xs = Tensor::stack(&res, 1)?;
+ for layer in self.lstms.iter() {
+ let states = layer.seq(&xs)?;
+ xs = layer.states_to_tensor(&states)?;
}
+ let xs = xs.t()?;
let embeds_raw = xs.apply(&self.linear)?.relu()?;
- // TODO: normalize.
- Ok(embeds_raw)
+ let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
+ embeds_raw.broadcast_div(&norm)
}
}
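The forward pass now ends with the row-wise L2 normalization that the earlier version left as a TODO: each embedding e is scaled to e / ||e||_2. A self-contained sketch of just that step, with hypothetical example values:

```rust
use candle::{Device, Result, Tensor};

// Row-wise L2 normalization, as in the last two lines of forward().
fn l2_normalize_rows(xs: &Tensor) -> Result<Tensor> {
    let norm = xs.sqr()?.sum_keepdim(1)?.sqrt()?;
    xs.broadcast_div(&norm)
}

fn main() -> Result<()> {
    let xs = Tensor::new(&[[3f32, 4.], [0., 2.]], &Device::Cpu)?;
    // Each row now has unit norm: [[0.6, 0.8], [0.0, 1.0]]
    println!("{}", l2_normalize_rows(&xs)?);
    Ok(())
}
```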
}
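Putting the pieces together, the new API can be exercised roughly as follows. This is a hedged usage sketch: the weight path is a placeholder, and rate = 1.3 / min_coverage = 0.75 are assumed defaults from the Python reference, not values taken from this diff:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::metavoice::speaker_encoder;

fn speaker_embedding(wav: &[f32], mel_filters: &[f32]) -> Result<Tensor> {
    let device = Device::Cpu;
    // Placeholder weight file for the speaker encoder.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["speaker_encoder.safetensors"], DType::F32, &device)?
    };
    let model = speaker_encoder::Model::new(speaker_encoder::Config::cfg(), vb)?;
    model.embed_utterance(wav, mel_filters, /* rate */ 1.3, /* min_coverage */ 0.75, &device)
}
```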
diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs
index eb795f18..35f9f3df 100644
--- a/candle-transformers/src/models/whisper/audio.rs
+++ b/candle-transformers/src/models/whisper/audio.rs
@@ -167,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
mel
}
-fn log_mel_spectrogram_<T: Float>(
+pub fn log_mel_spectrogram_<T: Float>(
samples: &[T],
filters: &[T],
fft_size: usize,