summaryrefslogtreecommitdiff
path: root/candle-examples/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-02-27 22:59:40 +0100
committerGitHub <noreply@github.com>2024-02-27 22:59:40 +0100
commit0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1 (patch)
treec732811778ea6e15c558dcbe35153cd110eb5959 /candle-examples/src
parent205767f9ded3d531822d3702442a52b4a320f72e (diff)
downloadcandle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.tar.gz
candle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.tar.bz2
candle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.zip
Encodec model. (#1771)
* Encodec model. * Fixes. * Add the padding functions. * Get the LSTM bit to work. * Get the encodec model to generate some tokens (decoder only for now). * Minor tweak. * Minor tweak.
Diffstat (limited to 'candle-examples/src')
-rw-r--r--candle-examples/src/lib.rs1
-rw-r--r--candle-examples/src/wav.rs56
2 files changed, 57 insertions, 0 deletions
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index d6dce4a3..7cb8eb01 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -1,6 +1,7 @@
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
+pub mod wav;
use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};
diff --git a/candle-examples/src/wav.rs b/candle-examples/src/wav.rs
new file mode 100644
index 00000000..df98aa14
--- /dev/null
+++ b/candle-examples/src/wav.rs
@@ -0,0 +1,56 @@
+use std::io::prelude::*;
+
+pub trait Sample {
+ fn to_i16(&self) -> i16;
+}
+
+impl Sample for f32 {
+ fn to_i16(&self) -> i16 {
+ (self.clamp(-1.0, 1.0) * 32767.0) as i16
+ }
+}
+
+impl Sample for f64 {
+ fn to_i16(&self) -> i16 {
+ (self.clamp(-1.0, 1.0) * 32767.0) as i16
+ }
+}
+
+impl Sample for i16 {
+ fn to_i16(&self) -> i16 {
+ *self
+ }
+}
+
+pub fn write_pcm_as_wav<W: Write, S: Sample>(
+ w: &mut W,
+ samples: &[S],
+ sample_rate: u32,
+) -> std::io::Result<()> {
+ let len = 12u32; // header
+ let len = len + 24u32; // fmt
+ let len = len + samples.len() as u32 * 2 + 8; // data
+ let n_channels = 1u16;
+ let bytes_per_second = sample_rate * 2 * n_channels as u32;
+ w.write_all(b"RIFF")?;
+ w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
+ w.write_all(b"WAVE")?;
+
+ // Format block
+ w.write_all(b"fmt ")?;
+ w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
+ w.write_all(&1u16.to_le_bytes())?; // PCM
+ w.write_all(&n_channels.to_le_bytes())?; // one channel
+ w.write_all(&sample_rate.to_le_bytes())?;
+ w.write_all(&bytes_per_second.to_le_bytes())?;
+ w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
+ w.write_all(&16u16.to_le_bytes())?; // bits per sample
+
+ // Data block
+ w.write_all(b"data")?;
+ w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
+ for sample in samples.iter() {
+ w.write_all(&sample.to_i16().to_le_bytes())?
+ }
+ Ok(())
+}