summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r--candle-examples/examples/whisper/main.rs17
1 files changed, 4 insertions, 13 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 82c45348..c9e9ccc6 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,7 +10,7 @@
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
-use candle::{safetensors::Load, DType, Device, Tensor};
+use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -243,13 +243,6 @@ struct Args {
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
-
- /// The mel filters in safetensors format.
- #[arg(
- long,
- default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
- )]
- filters: String,
}
fn main() -> Result<()> {
@@ -301,11 +294,9 @@ fn main() -> Result<()> {
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
- let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
- let mel_filters = mel_filters.deserialize()?;
- let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
- println!("loaded mel filters {:?}", mel_filters.shape());
- let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
+ let mel_bytes = include_bytes!("melfilters.bytes");
+ let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
+ <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;