diff options
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 17 |
1 files changed, 4 insertions, 13 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 82c45348..c9e9ccc6 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -10,7 +10,7 @@ extern crate intel_mkl_src; use anyhow::{Error as E, Result}; -use candle::{safetensors::Load, DType, Device, Tensor}; +use candle::{DType, Device, Tensor}; use candle_nn::{ops::softmax, VarBuilder}; use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; @@ -243,13 +243,6 @@ struct Args { /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, - - /// The mel filters in safetensors format. - #[arg( - long, - default_value = "candle-examples/examples/whisper/mel_filters.safetensors" - )] - filters: String, } fn main() -> Result<()> { @@ -301,11 +294,9 @@ fn main() -> Result<()> { }; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; - let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? }; - let mel_filters = mel_filters.deserialize()?; - let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?; - println!("loaded mel filters {:?}", mel_filters.shape()); - let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?; + let mel_bytes = include_bytes!("melfilters.bytes"); + let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; + <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters); let mut input = std::fs::File::open(input)?; let (header, data) = wav::read(&mut input)?; |