diff options
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/Cargo.toml | 9 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/main.rs | 12 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/utils.rs | 19 | ||||
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 37 | ||||
-rw-r--r-- | candle-examples/examples/whisper/melfilters.bytes (renamed from candle-examples/examples/whisper/mel_filters.safetensors) | bin | 64400 -> 64320 bytes | |||
-rw-r--r-- | candle-examples/examples/whisper/model.rs | 98 | ||||
-rw-r--r-- | candle-examples/src/lib.rs | 99 |
7 files changed, 210 insertions, 64 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index f3a4e325..54eb0be6 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -23,12 +23,13 @@ num-traits = { workspace = true } intel-mkl-src = { workspace = true, optional = true } cudarc = { workspace = true, optional = true } half = { workspace = true, optional = true } +image = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } byteorder = { workspace = true } +hf-hub = { workspace = true, features=["tokio"]} clap = { workspace = true } -hf-hub = { workspace = true } memmap2 = { workspace = true } rand = { workspace = true } tokenizers = { workspace = true, features = ["onig"] } @@ -36,6 +37,8 @@ tracing = { workspace = true } tracing-chrome = { workspace = true } tracing-subscriber = { workspace = true } wav = { workspace = true } +# Necessary to disambiguate with tokio in wasm examples which are 1.28.1 +tokio = "1.29.1" [build-dependencies] anyhow = { workspace = true } @@ -51,3 +54,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"] [[example]] name = "llama_multiprocess" required-features = ["cuda", "nccl", "flash-attn"] + +[[example]] +name = "stable-diffusion" +required-features = ["image"] diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 8bb3c56d..8ce0c234 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -245,10 +245,10 @@ fn run(args: Args) -> Result<()> { if args.intermediary_images { let image = vae.decode(&(&latents / 0.18215)?)?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; - let _image = (image * 255.)?.to_dtype(DType::U8); - let _image_filename = + let image = (image * 255.)?.to_dtype(DType::U8)?; + let image_filename = output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1)); - // TODO: save igame + crate::utils::save_image(&image, image_filename)? } } @@ -260,9 +260,9 @@ fn run(args: Args) -> Result<()> { let image = vae.decode(&(&latents / 0.18215)?)?; // TODO: Add the clamping between 0 and 1. let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; - let _image = (image * 255.)?.to_dtype(DType::U8); - let _image_filename = output_filename(&final_image, idx + 1, num_samples, None); - // TODO: save image. + let image = (image * 255.)?.to_dtype(DType::U8)?; + let image_filename = output_filename(&final_image, idx + 1, num_samples, None); + crate::utils::save_image(&image, image_filename)? } Ok(()) } diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 0c95cfef..ef4dd956 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -10,3 +10,22 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { .collect::<Vec<_>>(); Tensor::from_vec(vs, steps, &Device::Cpu) } + +/// Saves an image to disk using the image crate, this expects an input with shape +/// (c, width, height). +pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { + let p = p.as_ref(); + let (channel, width, height) = img.dims3()?; + if channel != 3 { + candle::bail!("save_image expects an input of shape (3, width, height)") + } + let img = img.transpose(0, 1)?.t()?.flatten_all()?; + let pixels = img.to_vec1::<u8>()?; + let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = + match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { + Some(image) => image, + None => candle::bail!("error saving image {p:?}"), + }; + image.save(p).map_err(candle::Error::wrap)?; + Ok(()) +} diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 82c45348..dfe7a27f 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] // https://github.com/openai/whisper/blob/main/whisper/model.py // TODO: // - kv-cache support? @@ -10,7 +9,7 @@ extern crate intel_mkl_src; use anyhow::{Error as E, Result}; -use candle::{safetensors::Load, DType, Device, Tensor}; +use candle::{DType, Device, Tensor}; use candle_nn::{ops::softmax, VarBuilder}; use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; @@ -31,9 +30,6 @@ const HOP_LENGTH: usize = 160; const CHUNK_LENGTH: usize = 30; const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input -const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2 -const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame -const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token const NO_SPEECH_THRESHOLD: f64 = 0.6; const LOGPROB_THRESHOLD: f64 = -1.0; @@ -44,7 +40,6 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; const SOT_TOKEN: u32 = 50257; const EOT_TOKEN: u32 = 50256; const NO_SPEECH_TOKEN: u32 = 50361; -const NO_TIMESTAMP_TOKEN: u32 = 50362; // From the _get_suppress_tokens function + 50362 (no timestamp) // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605 const SUPPRESS_TOKENS: [u32; 91] = [ @@ -56,6 +51,7 @@ const SUPPRESS_TOKENS: [u32; 91] = [ 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362, ]; +#[allow(dead_code)] #[derive(Debug, Clone)] struct DecodingResult { tokens: Vec<u32>, @@ -66,6 +62,7 @@ struct DecodingResult { compression_ratio: f64, } +#[allow(dead_code)] #[derive(Debug, Clone)] struct Segment { start: f64, @@ -244,16 +241,24 @@ struct Args { #[arg(long, default_value_t = 299792458)] seed: u64, - /// The mel filters in safetensors format. - #[arg( - long, - default_value = "candle-examples/examples/whisper/mel_filters.safetensors" - )] - filters: String, + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, } fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + let args = Args::parse(); + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; let device = candle_examples::device(args.cpu)?; let default_model = "openai/whisper-tiny.en".to_string(); let path = std::path::PathBuf::from(default_model.clone()); @@ -301,11 +306,9 @@ fn main() -> Result<()> { }; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; - let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? }; - let mel_filters = mel_filters.deserialize()?; - let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?; - println!("loaded mel filters {:?}", mel_filters.shape()); - let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?; + let mel_bytes = include_bytes!("melfilters.bytes"); + let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; + <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters); let mut input = std::fs::File::open(input)?; let (header, data) = wav::read(&mut input)?; diff --git a/candle-examples/examples/whisper/mel_filters.safetensors b/candle-examples/examples/whisper/melfilters.bytes Binary files differindex 98f3af44..0874829e 100644 --- a/candle-examples/examples/whisper/mel_filters.safetensors +++ b/candle-examples/examples/whisper/melfilters.bytes diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index 4d80c0c8..7015199d 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -1,8 +1,5 @@ -// We use anyhow rather than candle errors as it provides better support for getting the backtrace -// back when using RUST_LIB_BACKTRACE=1. -use anyhow::Result; -use candle::{Device, Tensor}; -use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder}; +use candle::{Device, Result, Tensor}; +use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder}; use serde::Deserialize; // The names in comments correspond to the original implementation: @@ -22,6 +19,7 @@ pub struct Config { } impl Config { + #[allow(dead_code)] pub fn tiny_en() -> Self { Self { num_mel_bins: 80, @@ -42,16 +40,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em let embeddings = vb.get((vocab_size, hidden_size), "weight")?; Ok(Embedding::new(embeddings, hidden_size)) } +// +// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting +// model. +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), "weight")?; - let bias = vb.get(size2, "bias")?; - Ok(Linear::new(weight, Some(bias))) + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear(size1, size2, vb)?; + Ok(Linear { inner, span }) } fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), "weight")?; - Ok(Linear::new(weight, None)) + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear_no_bias(size1, size2, vb)?; + Ok(Linear { inner, span }) } fn conv1d( @@ -66,32 +80,6 @@ fn conv1d( Ok(Conv1d::new(weight, Some(bias), config)) } -fn conv1d_no_bias( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - config: Conv1dConfig, - vb: VarBuilder, -) -> Result<Conv1d> { - let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; - Ok(Conv1d::new(weight, None, config)) -} - -struct Dropout { - pr: f64, -} - -impl Dropout { - fn new(pr: f64) -> Self { - Self { pr } - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - // TODO - Ok(x.clone()) - } -} - fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { let weight = vb.get(size, "weight")?; let bias = vb.get(size, "bias")?; @@ -105,10 +93,12 @@ struct MultiHeadAttention { value: Linear, out: Linear, n_head: usize, + span: tracing::Span, } impl MultiHeadAttention { fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); let query = linear(n_state, n_state, vb.pp("q_proj"))?; let value = linear(n_state, n_state, vb.pp("v_proj"))?; let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; @@ -119,10 +109,12 @@ impl MultiHeadAttention { value, out, n_head, + span, }) } fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); let q = self.query.forward(x)?; let k = self.key.forward(xa.unwrap_or(x))?; let v = self.value.forward(xa.unwrap_or(x))?; @@ -134,7 +126,7 @@ impl MultiHeadAttention { fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { let (n_batch, n_ctx, n_state) = x.dims3()?; let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; - Ok(x.reshape(target_dims)?.transpose(1, 2)?) + x.reshape(target_dims)?.transpose(1, 2) } fn qkv_attention( @@ -168,10 +160,12 @@ struct ResidualAttentionBlock { mlp_linear1: Linear, mlp_linear2: Linear, mlp_ln: LayerNorm, + span: tracing::Span, } impl ResidualAttentionBlock { fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; let cross_attn = if ca { @@ -192,10 +186,12 @@ impl ResidualAttentionBlock { mlp_linear1, mlp_linear2, mlp_ln, + span, }) } fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?; let mut x = (x + attn)?; if let Some((attn, ln)) = &self.cross_attn { @@ -207,7 +203,7 @@ impl ResidualAttentionBlock { .forward(&self.mlp_ln.forward(&x)?)? .gelu()?, )?; - Ok((x + mlp)?) + x + mlp } } @@ -234,10 +230,16 @@ pub struct AudioEncoder { positional_embedding: Tensor, blocks: Vec<ResidualAttentionBlock>, ln_post: LayerNorm, + span: tracing::Span, + conv1_span: tracing::Span, + conv2_span: tracing::Span, } impl AudioEncoder { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); + let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); + let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); let n_state = cfg.d_model; let n_head = cfg.encoder_attention_heads; let n_ctx = cfg.max_source_positions; @@ -264,11 +266,22 @@ impl AudioEncoder { positional_embedding, blocks, ln_post, + conv1_span, + conv2_span, + span, }) } + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = self.conv1.forward(x)?.gelu()?; - let x = self.conv2.forward(&x)?.gelu()?; + let _enter = self.span.enter(); + let x = { + let _enter = self.conv1_span.enter(); + self.conv1.forward(x)?.gelu()? + }; + let x = { + let _enter = self.conv2_span.enter(); + self.conv2.forward(&x)?.gelu()? + }; let x = x.transpose(1, 2)?; let (_bsize, seq_len, _hidden) = x.dims3()?; let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; @@ -288,10 +301,12 @@ pub struct TextDecoder { blocks: Vec<ResidualAttentionBlock>, ln: LayerNorm, mask: Tensor, + span: tracing::Span, } impl TextDecoder { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); let n_state = cfg.d_model; let n_head = cfg.decoder_attention_heads; let n_ctx = cfg.max_target_positions; @@ -314,10 +329,12 @@ impl TextDecoder { blocks, ln, mask, + span, }) } pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); let x_dims = x.dims(); let last = x_dims[x_dims.len() - 1]; let token_embedding = self.token_embedding.forward(x)?; @@ -354,6 +371,7 @@ impl Whisper { }) } + #[allow(dead_code)] pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> { let enc = self.encoder.forward(mel)?; let dec = self.decoder.forward(tokens, &enc)?; diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 285aee04..2b6009b4 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> { Ok(device) } } + +#[cfg(test)] +mod tests { + // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 + #[rustfmt::skip] + #[tokio::test] + async fn book_hub_1() { +// ANCHOR: book_hub_1 +use candle::Device; +use hf_hub::api::tokio::Api; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); + +let weights_filename = repo.get("model.safetensors").await.unwrap(); + +let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_1 + assert_eq!(weights.len(), 206); + } + + #[rustfmt::skip] + #[test] + fn book_hub_2() { +// ANCHOR: book_hub_2 +use candle::Device; +use hf_hub::api::sync::Api; +use memmap2::Mmap; +use std::fs; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); +let weights_filename = repo.get("model.safetensors").unwrap(); + +let file = fs::File::open(weights_filename).unwrap(); +let mmap = unsafe { Mmap::map(&file).unwrap() }; +let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_2 + assert_eq!(weights.len(), 206); + } + + #[rustfmt::skip] + #[test] + fn book_hub_3() { +// ANCHOR: book_hub_3 +use candle::{DType, Device, Tensor}; +use hf_hub::api::sync::Api; +use memmap2::Mmap; +use safetensors::slice::IndexOp; +use safetensors::SafeTensors; +use std::fs; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); +let weights_filename = repo.get("model.safetensors").unwrap(); + +let file = fs::File::open(weights_filename).unwrap(); +let mmap = unsafe { Mmap::map(&file).unwrap() }; + +// Use safetensors directly +let tensors = SafeTensors::deserialize(&mmap[..]).unwrap(); +let view = tensors + .tensor("bert.encoder.layer.0.attention.self.query.weight") + .unwrap(); + +// We're going to load shard with rank 1, within a world_size of 4 +// We're going to split along dimension 0 doing VIEW[start..stop, :] +let rank = 1; +let world_size = 4; +let dim = 0; +let dtype = view.dtype(); +let mut tp_shape = view.shape().to_vec(); +let size = tp_shape[0]; + +if size % world_size != 0 { + panic!("The dimension is not divisble by `world_size`"); +} +let block_size = size / world_size; +let start = rank * block_size; +let stop = (rank + 1) * block_size; + +// Everything is expressed in tensor dimension +// bytes offsets is handled automatically for safetensors. + +let iterator = view.slice(start..stop).unwrap(); + +tp_shape[dim] = block_size; + +// Convert safetensors Dtype to candle DType +let dtype: DType = dtype.try_into().unwrap(); + +// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc. +let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); +let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_3 + assert_eq!(view.shape(), &[768, 768]); + assert_eq!(tp_tensor.dims(), &[192, 768]); + } +} |