summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/Cargo.toml9
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs12
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs19
-rw-r--r--candle-examples/examples/whisper/main.rs37
-rw-r--r--candle-examples/examples/whisper/melfilters.bytes (renamed from candle-examples/examples/whisper/mel_filters.safetensors)bin64400 -> 64320 bytes
-rw-r--r--candle-examples/examples/whisper/model.rs98
-rw-r--r--candle-examples/src/lib.rs99
7 files changed, 210 insertions, 64 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index f3a4e325..54eb0be6 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -23,12 +23,13 @@ num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
+image = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
+hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
-hf-hub = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
@@ -36,6 +37,8 @@ tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
+# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
+tokio = "1.29.1"
[build-dependencies]
anyhow = { workspace = true }
@@ -51,3 +54,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
+
+[[example]]
+name = "stable-diffusion"
+required-features = ["image"]
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 8bb3c56d..8ce0c234 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -245,10 +245,10 @@ fn run(args: Args) -> Result<()> {
if args.intermediary_images {
let image = vae.decode(&(&latents / 0.18215)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
- let _image = (image * 255.)?.to_dtype(DType::U8);
- let _image_filename =
+ let image = (image * 255.)?.to_dtype(DType::U8)?;
+ let image_filename =
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
- // TODO: save igame
+ crate::utils::save_image(&image, image_filename)?
}
}
@@ -260,9 +260,9 @@ fn run(args: Args) -> Result<()> {
let image = vae.decode(&(&latents / 0.18215)?)?;
// TODO: Add the clamping between 0 and 1.
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
- let _image = (image * 255.)?.to_dtype(DType::U8);
- let _image_filename = output_filename(&final_image, idx + 1, num_samples, None);
- // TODO: save image.
+ let image = (image * 255.)?.to_dtype(DType::U8)?;
+ let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
+ crate::utils::save_image(&image, image_filename)?
}
Ok(())
}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
index 0c95cfef..ef4dd956 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -10,3 +10,22 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}
+
+/// Saves an image to disk using the image crate, this expects an input with shape
+/// (c, width, height).
+pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
+ let p = p.as_ref();
+ let (channel, width, height) = img.dims3()?;
+ if channel != 3 {
+ candle::bail!("save_image expects an input of shape (3, width, height)")
+ }
+ let img = img.transpose(0, 1)?.t()?.flatten_all()?;
+ let pixels = img.to_vec1::<u8>()?;
+ let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+ Some(image) => image,
+ None => candle::bail!("error saving image {p:?}"),
+ };
+ image.save(p).map_err(candle::Error::wrap)?;
+ Ok(())
+}
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 82c45348..dfe7a27f 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
// https://github.com/openai/whisper/blob/main/whisper/model.py
// TODO:
// - kv-cache support?
@@ -10,7 +9,7 @@
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
-use candle::{safetensors::Load, DType, Device, Tensor};
+use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -31,9 +30,6 @@ const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
-const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
-const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
const NO_SPEECH_THRESHOLD: f64 = 0.6;
const LOGPROB_THRESHOLD: f64 = -1.0;
@@ -44,7 +40,6 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
const SOT_TOKEN: u32 = 50257;
const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
-const NO_TIMESTAMP_TOKEN: u32 = 50362;
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
const SUPPRESS_TOKENS: [u32; 91] = [
@@ -56,6 +51,7 @@ const SUPPRESS_TOKENS: [u32; 91] = [
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
+#[allow(dead_code)]
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
@@ -66,6 +62,7 @@ struct DecodingResult {
compression_ratio: f64,
}
+#[allow(dead_code)]
#[derive(Debug, Clone)]
struct Segment {
start: f64,
@@ -244,16 +241,24 @@ struct Args {
#[arg(long, default_value_t = 299792458)]
seed: u64,
- /// The mel filters in safetensors format.
- #[arg(
- long,
- default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
- )]
- filters: String,
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
}
fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
let args = Args::parse();
+ let _guard = if args.tracing {
+ println!("tracing...");
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
let device = candle_examples::device(args.cpu)?;
let default_model = "openai/whisper-tiny.en".to_string();
let path = std::path::PathBuf::from(default_model.clone());
@@ -301,11 +306,9 @@ fn main() -> Result<()> {
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
- let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
- let mel_filters = mel_filters.deserialize()?;
- let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
- println!("loaded mel filters {:?}", mel_filters.shape());
- let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
+ let mel_bytes = include_bytes!("melfilters.bytes");
+ let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
+ <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
diff --git a/candle-examples/examples/whisper/mel_filters.safetensors b/candle-examples/examples/whisper/melfilters.bytes
index 98f3af44..0874829e 100644
--- a/candle-examples/examples/whisper/mel_filters.safetensors
+++ b/candle-examples/examples/whisper/melfilters.bytes
Binary files differ
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 4d80c0c8..7015199d 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -1,8 +1,5 @@
-// We use anyhow rather than candle errors as it provides better support for getting the backtrace
-// back when using RUST_LIB_BACKTRACE=1.
-use anyhow::Result;
-use candle::{Device, Tensor};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle::{Device, Result, Tensor};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@@ -22,6 +19,7 @@ pub struct Config {
}
impl Config {
+ #[allow(dead_code)]
pub fn tiny_en() -> Self {
Self {
num_mel_bins: 80,
@@ -42,16 +40,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
+//
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- let bias = vb.get(size2, "bias")?;
- Ok(Linear::new(weight, Some(bias)))
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear(size1, size2, vb)?;
+ Ok(Linear { inner, span })
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- Ok(Linear::new(weight, None))
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+ Ok(Linear { inner, span })
}
fn conv1d(
@@ -66,32 +80,6 @@ fn conv1d(
Ok(Conv1d::new(weight, Some(bias), config))
}
-fn conv1d_no_bias(
- in_channels: usize,
- out_channels: usize,
- kernel_size: usize,
- config: Conv1dConfig,
- vb: VarBuilder,
-) -> Result<Conv1d> {
- let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
- Ok(Conv1d::new(weight, None, config))
-}
-
-struct Dropout {
- pr: f64,
-}
-
-impl Dropout {
- fn new(pr: f64) -> Self {
- Self { pr }
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- // TODO
- Ok(x.clone())
- }
-}
-
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
@@ -105,10 +93,12 @@ struct MultiHeadAttention {
value: Linear,
out: Linear,
n_head: usize,
+ span: tracing::Span,
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@@ -119,10 +109,12 @@ impl MultiHeadAttention {
value,
out,
n_head,
+ span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
@@ -134,7 +126,7 @@ impl MultiHeadAttention {
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
- Ok(x.reshape(target_dims)?.transpose(1, 2)?)
+ x.reshape(target_dims)?.transpose(1, 2)
}
fn qkv_attention(
@@ -168,10 +160,12 @@ struct ResidualAttentionBlock {
mlp_linear1: Linear,
mlp_linear2: Linear,
mlp_ln: LayerNorm,
+ span: tracing::Span,
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
@@ -192,10 +186,12 @@ impl ResidualAttentionBlock {
mlp_linear1,
mlp_linear2,
mlp_ln,
+ span,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
@@ -207,7 +203,7 @@ impl ResidualAttentionBlock {
.forward(&self.mlp_ln.forward(&x)?)?
.gelu()?,
)?;
- Ok((x + mlp)?)
+ x + mlp
}
}
@@ -234,10 +230,16 @@ pub struct AudioEncoder {
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
+ span: tracing::Span,
+ conv1_span: tracing::Span,
+ conv2_span: tracing::Span,
}
impl AudioEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
+ let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
+ let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
@@ -264,11 +266,22 @@ impl AudioEncoder {
positional_embedding,
blocks,
ln_post,
+ conv1_span,
+ conv2_span,
+ span,
})
}
+
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = self.conv1.forward(x)?.gelu()?;
- let x = self.conv2.forward(&x)?.gelu()?;
+ let _enter = self.span.enter();
+ let x = {
+ let _enter = self.conv1_span.enter();
+ self.conv1.forward(x)?.gelu()?
+ };
+ let x = {
+ let _enter = self.conv2_span.enter();
+ self.conv2.forward(&x)?.gelu()?
+ };
let x = x.transpose(1, 2)?;
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
@@ -288,10 +301,12 @@ pub struct TextDecoder {
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
+ span: tracing::Span,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
@@ -314,10 +329,12 @@ impl TextDecoder {
blocks,
ln,
mask,
+ span,
})
}
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
@@ -354,6 +371,7 @@ impl Whisper {
})
}
+ #[allow(dead_code)]
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
let enc = self.encoder.forward(mel)?;
let dec = self.decoder.forward(tokens, &enc)?;
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 285aee04..2b6009b4 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> {
Ok(device)
}
}
+
+#[cfg(test)]
+mod tests {
+ // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
+ #[rustfmt::skip]
+ #[tokio::test]
+ async fn book_hub_1() {
+// ANCHOR: book_hub_1
+use candle::Device;
+use hf_hub::api::tokio::Api;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+
+let weights_filename = repo.get("model.safetensors").await.unwrap();
+
+let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_1
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_2() {
+// ANCHOR: book_hub_2
+use candle::Device;
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_2
+ assert_eq!(weights.len(), 206);
+ }
+
+ #[rustfmt::skip]
+ #[test]
+ fn book_hub_3() {
+// ANCHOR: book_hub_3
+use candle::{DType, Device, Tensor};
+use hf_hub::api::sync::Api;
+use memmap2::Mmap;
+use safetensors::slice::IndexOp;
+use safetensors::SafeTensors;
+use std::fs;
+
+let api = Api::new().unwrap();
+let repo = api.model("bert-base-uncased".to_string());
+let weights_filename = repo.get("model.safetensors").unwrap();
+
+let file = fs::File::open(weights_filename).unwrap();
+let mmap = unsafe { Mmap::map(&file).unwrap() };
+
+// Use safetensors directly
+let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
+let view = tensors
+ .tensor("bert.encoder.layer.0.attention.self.query.weight")
+ .unwrap();
+
+// We're going to load shard with rank 1, within a world_size of 4
+// We're going to split along dimension 0 doing VIEW[start..stop, :]
+let rank = 1;
+let world_size = 4;
+let dim = 0;
+let dtype = view.dtype();
+let mut tp_shape = view.shape().to_vec();
+let size = tp_shape[0];
+
+if size % world_size != 0 {
+ panic!("The dimension is not divisble by `world_size`");
+}
+let block_size = size / world_size;
+let start = rank * block_size;
+let stop = (rank + 1) * block_size;
+
+// Everything is expressed in tensor dimension
+// bytes offsets is handled automatically for safetensors.
+
+let iterator = view.slice(start..stop).unwrap();
+
+tp_shape[dim] = block_size;
+
+// Convert safetensors Dtype to candle DType
+let dtype: DType = dtype.try_into().unwrap();
+
+// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
+let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
+// ANCHOR_END: book_hub_3
+ assert_eq!(view.shape(), &[768, 768]);
+ assert_eq!(tp_tensor.dims(), &[192, 768]);
+ }
+}