summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- candle-examples/examples/whisper/audio.rs 23
-rw-r--r-- candle-examples/examples/whisper/main.rs 6
2 files changed, 13 insertions, 16 deletions
diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs
index d50b7923..d095e239 100644
--- a/candle-examples/examples/whisper/audio.rs
+++ b/candle-examples/examples/whisper/audio.rs
@@ -148,7 +148,7 @@ fn log_mel_spectrogram_w<T: Float>(
mel
}
-fn log_mel_spectrogram_<T: Float>(
+fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
samples: &[T],
filters: &[T],
fft_size: usize,
@@ -200,20 +200,17 @@ fn log_mel_spectrogram_<T: Float>(
mel
}
-pub fn pcm_to_mel<T: Float>(
+pub fn pcm_to_mel<T: Float + std::fmt::Display>(
samples: &[T],
filters: &[T],
- n_mel: usize,
- n_fft: usize,
) -> anyhow::Result<Vec<T>> {
- if filters.len() != n_mel * n_fft {
- anyhow::bail!(
- "unexpected filter length {} (n_mel: {}, n_fft: {})",
- filters.len(),
- n_mel,
- n_fft
- )
- }
- let mel = log_mel_spectrogram_(samples, filters, n_fft, super::HOP_LENGTH, n_mel, false);
+ let mel = log_mel_spectrogram_(
+ samples,
+ filters,
+ super::N_FFT,
+ super::HOP_LENGTH,
+ super::N_MELS,
+ false,
+ );
Ok(mel)
}
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 6e15fa8a..6ea3e536 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -201,7 +201,6 @@ fn main() -> Result<()> {
let mel_filters = mel_filters.deserialize()?;
let mel_filters = mel_filters.tensor("mel_80", &device)?;
println!("loaded mel filters {:?}", mel_filters.shape());
- let (n_mel, n_fft) = mel_filters.shape().r2()?;
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let mut input = std::fs::File::open(args.input)?;
@@ -215,9 +214,10 @@ fn main() -> Result<()> {
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
- let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?;
+ println!("pcm data loaded {}", pcm_data.len());
+ let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
let mel_len = mel.len();
- let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?;
+ let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };