Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/bert/main.rs | 3
-rw-r--r-- | candle-examples/examples/bert/model.rs | 568
-rw-r--r-- | candle-examples/examples/bigcode/main.rs | 3
-rw-r--r-- | candle-examples/examples/bigcode/model.rs | 359
-rw-r--r-- | candle-examples/examples/falcon/main.rs | 3
-rw-r--r-- | candle-examples/examples/falcon/model.rs | 485
-rw-r--r-- | candle-examples/examples/llama/main.rs | 3
-rw-r--r-- | candle-examples/examples/llama/model.rs | 446
-rw-r--r-- | candle-examples/examples/whisper/audio.rs | 214
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 71
-rw-r--r-- | candle-examples/examples/whisper/model.rs | 416
-rw-r--r-- | candle-examples/examples/whisper/multilingual.rs | 2
12 files changed, 28 insertions, 2545 deletions
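Every hunk below follows the same pattern: the per-example model.rs is deleted and the `mod model;` line in main.rs is replaced by an import from candle-transformers (for instance `use candle_transformers::models::bert::{BertModel, Config, DTYPE};`). As a rough illustration of how an example consumes the relocated code, here is a minimal sketch for the BERT case. It assumes the moved module keeps the same Config / BertModel::load / forward API as the deleted file shown in the diff, and the serde_json and VarBuilder::from_mmaped_safetensors usage mirrors what the original example did rather than anything introduced by this change.

use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};

// Load the relocated BERT model from a config.json string and a safetensors file.
fn load_bert(config_json: &str, weights: &std::path::Path, device: &Device) -> Result<BertModel> {
    // Config derives serde::Deserialize, so it can be parsed straight from config.json.
    let config: Config = serde_json::from_str(config_json)?;
    // DTYPE (f32) is re-exported by the bert module, as in the deleted examples/bert/model.rs.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DTYPE, device)? };
    Ok(BertModel::load(vb, &config)?)
}

// Run the encoder; the forward signature matches the deleted model.rs.
fn embed(model: &BertModel, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
    Ok(model.forward(input_ids, token_type_ids)?)
}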
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 6cee66ee..9d0eccdf 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -3,14 +3,13 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -mod model; +use candle_transformers::models::bert::{BertModel, Config, DTYPE}; use anyhow::{anyhow, Error as E, Result}; use candle::Tensor; use candle_nn::VarBuilder; use clap::Parser; use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; -use model::{BertModel, Config, DTYPE}; use tokenizers::{PaddingParams, Tokenizer}; #[derive(Parser, Debug)] diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs deleted file mode 100644 index 3f164a3a..00000000 --- a/candle-examples/examples/bert/model.rs +++ /dev/null @@ -1,568 +0,0 @@ -use candle::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, VarBuilder}; -use serde::Deserialize; - -pub const DTYPE: DType = DType::F32; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] -#[serde(rename_all = "lowercase")] -enum HiddenAct { - Gelu, - Relu, -} - -struct HiddenActLayer { - act: HiddenAct, - span: tracing::Span, -} - -impl HiddenActLayer { - fn new(act: HiddenAct) -> Self { - let span = tracing::span!(tracing::Level::TRACE, "hidden-act"); - Self { act, span } - } - - fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { - let _enter = self.span.enter(); - match self.act { - // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some - // small numerical difference. - // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 - HiddenAct::Gelu => xs.gelu(), - HiddenAct::Relu => xs.relu(), - } - } -} - -#[derive(Debug)] -pub struct Linear { - weight: Tensor, - bias: Option<Tensor>, - span: tracing::Span, -} - -impl Linear { - pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - Self { weight, bias, span } - } - - pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - let _enter = self.span.enter(); - let w = match x.dims() { - &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, - _ => self.weight.t()?, - }; - let x = x.matmul(&w)?; - match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - } - } -} - -#[derive(Debug)] -pub struct LayerNorm { - weight: Tensor, - bias: Tensor, - eps: f64, - span: tracing::Span, -} - -impl LayerNorm { - pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { - let span = tracing::span!(tracing::Level::TRACE, "layer-norm"); - Self { - weight, - bias, - eps, - span, - } - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let x_dtype = x.dtype(); - let internal_dtype = match x_dtype { - DType::F16 | DType::BF16 => DType::F32, - d => d, - }; - let (_bsize, _seq_len, hidden_size) = x.dims3()?; - let x = x.to_dtype(internal_dtype)?; - let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; - let x = x.broadcast_sub(&mean_x)?; - let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; - let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; - let x = x_normed - .to_dtype(x_dtype)? - .broadcast_mul(&self.weight)? 
- .broadcast_add(&self.bias)?; - Ok(x) - } -} -#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] -#[serde(rename_all = "lowercase")] -enum PositionEmbeddingType { - #[default] - Absolute, -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1 -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub struct Config { - vocab_size: usize, - hidden_size: usize, - num_hidden_layers: usize, - num_attention_heads: usize, - intermediate_size: usize, - hidden_act: HiddenAct, - hidden_dropout_prob: f64, - max_position_embeddings: usize, - type_vocab_size: usize, - initializer_range: f64, - layer_norm_eps: f64, - pad_token_id: usize, - #[serde(default)] - position_embedding_type: PositionEmbeddingType, - #[serde(default)] - use_cache: bool, - classifier_dropout: Option<f64>, - model_type: Option<String>, -} - -impl Default for Config { - fn default() -> Self { - Self { - vocab_size: 30522, - hidden_size: 768, - num_hidden_layers: 12, - num_attention_heads: 12, - intermediate_size: 3072, - hidden_act: HiddenAct::Gelu, - hidden_dropout_prob: 0.1, - max_position_embeddings: 512, - type_vocab_size: 2, - initializer_range: 0.02, - layer_norm_eps: 1e-12, - pad_token_id: 0, - position_embedding_type: PositionEmbeddingType::Absolute, - use_cache: true, - classifier_dropout: None, - model_type: Some("bert".to_string()), - } - } -} - -impl Config { - fn _all_mini_lm_l6_v2() -> Self { - // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json - Self { - vocab_size: 30522, - hidden_size: 384, - num_hidden_layers: 6, - num_attention_heads: 12, - intermediate_size: 1536, - hidden_act: HiddenAct::Gelu, - hidden_dropout_prob: 0.1, - max_position_embeddings: 512, - type_vocab_size: 2, - initializer_range: 0.02, - layer_norm_eps: 1e-12, - pad_token_id: 0, - position_embedding_type: PositionEmbeddingType::Absolute, - use_cache: true, - classifier_dropout: None, - model_type: Some("bert".to_string()), - } - } -} - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), "weight")?; - let bias = vb.get(size2, "bias")?; - Ok(Linear::new(weight, Some(bias))) -} - -struct Dropout { - #[allow(dead_code)] - pr: f64, -} - -impl Dropout { - fn new(pr: f64) -> Self { - Self { pr } - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - // TODO - Ok(x.clone()) - } -} - -fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { - let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { - (Ok(weight), Ok(bias)) => (weight, bias), - (Err(err), _) | (_, Err(err)) => { - if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { - (weight, bias) - } else { - return Err(err); - } - } - }; - Ok(LayerNorm::new(weight, bias, eps)) -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180 -struct BertEmbeddings { - word_embeddings: Embedding, - position_embeddings: Option<Embedding>, - token_type_embeddings: Embedding, - layer_norm: LayerNorm, - dropout: Dropout, - span: tracing::Span, -} - -impl BertEmbeddings { - fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let 
word_embeddings = embedding( - config.vocab_size, - config.hidden_size, - vb.pp("word_embeddings"), - )?; - let position_embeddings = embedding( - config.max_position_embeddings, - config.hidden_size, - vb.pp("position_embeddings"), - )?; - let token_type_embeddings = embedding( - config.type_vocab_size, - config.hidden_size, - vb.pp("token_type_embeddings"), - )?; - let layer_norm = layer_norm( - config.hidden_size, - config.layer_norm_eps, - vb.pp("LayerNorm"), - )?; - Ok(Self { - word_embeddings, - position_embeddings: Some(position_embeddings), - token_type_embeddings, - layer_norm, - dropout: Dropout::new(config.hidden_dropout_prob), - span: tracing::span!(tracing::Level::TRACE, "embeddings"), - }) - } - - fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let (_bsize, seq_len) = input_ids.dims2()?; - let input_embeddings = self.word_embeddings.forward(input_ids)?; - let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; - let mut embeddings = (&input_embeddings + token_type_embeddings)?; - if let Some(position_embeddings) = &self.position_embeddings { - // TODO: Proper absolute positions? - let position_ids = (0..seq_len as u32).collect::<Vec<_>>(); - let position_ids = Tensor::new(&position_ids[..], input_ids.device())?; - embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? - } - let embeddings = self.layer_norm.forward(&embeddings)?; - let embeddings = self.dropout.forward(&embeddings)?; - Ok(embeddings) - } -} - -struct BertSelfAttention { - query: Linear, - key: Linear, - value: Linear, - dropout: Dropout, - num_attention_heads: usize, - attention_head_size: usize, - span: tracing::Span, - span_softmax: tracing::Span, -} - -impl BertSelfAttention { - fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let attention_head_size = config.hidden_size / config.num_attention_heads; - let all_head_size = config.num_attention_heads * attention_head_size; - let dropout = Dropout::new(config.hidden_dropout_prob); - let hidden_size = config.hidden_size; - let query = linear(hidden_size, all_head_size, vb.pp("query"))?; - let value = linear(hidden_size, all_head_size, vb.pp("value"))?; - let key = linear(hidden_size, all_head_size, vb.pp("key"))?; - Ok(Self { - query, - key, - value, - dropout, - num_attention_heads: config.num_attention_heads, - attention_head_size, - span: tracing::span!(tracing::Level::TRACE, "self-attn"), - span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"), - }) - } - - fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> { - let mut new_x_shape = xs.dims().to_vec(); - new_x_shape.pop(); - new_x_shape.push(self.num_attention_heads); - new_x_shape.push(self.attention_head_size); - let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; - xs.contiguous() - } - - fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let query_layer = self.query.forward(hidden_states)?; - let key_layer = self.key.forward(hidden_states)?; - let value_layer = self.value.forward(hidden_states)?; - - let query_layer = self.transpose_for_scores(&query_layer)?; - let key_layer = self.transpose_for_scores(&key_layer)?; - let value_layer = self.transpose_for_scores(&value_layer)?; - - let attention_scores = query_layer.matmul(&key_layer.t()?)?; - let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; - let attention_probs = { - let _enter_sm = 
self.span_softmax.enter(); - candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)? - }; - let attention_probs = self.dropout.forward(&attention_probs)?; - - let context_layer = attention_probs.matmul(&value_layer)?; - let context_layer = context_layer.transpose(1, 2)?.contiguous()?; - let context_layer = context_layer.flatten_from(candle::D::Minus2)?; - Ok(context_layer) - } -} - -struct BertSelfOutput { - dense: Linear, - layer_norm: LayerNorm, - dropout: Dropout, - span: tracing::Span, -} - -impl BertSelfOutput { - fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; - let layer_norm = layer_norm( - config.hidden_size, - config.layer_norm_eps, - vb.pp("LayerNorm"), - )?; - let dropout = Dropout::new(config.hidden_dropout_prob); - Ok(Self { - dense, - layer_norm, - dropout, - span: tracing::span!(tracing::Level::TRACE, "self-out"), - }) - } - - fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let hidden_states = self.dense.forward(hidden_states)?; - let hidden_states = self.dropout.forward(&hidden_states)?; - self.layer_norm.forward(&(hidden_states + input_tensor)?) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392 -struct BertAttention { - self_attention: BertSelfAttention, - self_output: BertSelfOutput, - span: tracing::Span, -} - -impl BertAttention { - fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let self_attention = BertSelfAttention::load(vb.pp("self"), config)?; - let self_output = BertSelfOutput::load(vb.pp("output"), config)?; - Ok(Self { - self_attention, - self_output, - span: tracing::span!(tracing::Level::TRACE, "attn"), - }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let self_outputs = self.self_attention.forward(hidden_states)?; - let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; - Ok(attention_output) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441 -struct BertIntermediate { - dense: Linear, - intermediate_act: HiddenActLayer, - span: tracing::Span, -} - -impl BertIntermediate { - fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; - Ok(Self { - dense, - intermediate_act: HiddenActLayer::new(config.hidden_act), - span: tracing::span!(tracing::Level::TRACE, "inter"), - }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let hidden_states = self.dense.forward(hidden_states)?; - let ys = self.intermediate_act.forward(&hidden_states)?; - Ok(ys) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456 -struct BertOutput { - dense: Linear, - layer_norm: LayerNorm, - dropout: Dropout, - span: tracing::Span, -} - -impl BertOutput { - fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; - let layer_norm = layer_norm( - config.hidden_size, - config.layer_norm_eps, - vb.pp("LayerNorm"), - )?; - let dropout = Dropout::new(config.hidden_dropout_prob); - Ok(Self { - dense, - 
layer_norm, - dropout, - span: tracing::span!(tracing::Level::TRACE, "out"), - }) - } - - fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let hidden_states = self.dense.forward(hidden_states)?; - let hidden_states = self.dropout.forward(&hidden_states)?; - self.layer_norm.forward(&(hidden_states + input_tensor)?) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470 -struct BertLayer { - attention: BertAttention, - intermediate: BertIntermediate, - output: BertOutput, - span: tracing::Span, -} - -impl BertLayer { - fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let attention = BertAttention::load(vb.pp("attention"), config)?; - let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?; - let output = BertOutput::load(vb.pp("output"), config)?; - Ok(Self { - attention, - intermediate, - output, - span: tracing::span!(tracing::Level::TRACE, "layer"), - }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let attention_output = self.attention.forward(hidden_states)?; - // TODO: Support cross-attention? - // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 - // TODO: Support something similar to `apply_chunking_to_forward`? - let intermediate_output = self.intermediate.forward(&attention_output)?; - let layer_output = self - .output - .forward(&intermediate_output, &attention_output)?; - Ok(layer_output) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556 -struct BertEncoder { - layers: Vec<BertLayer>, - span: tracing::Span, -} - -impl BertEncoder { - fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let layers = (0..config.num_hidden_layers) - .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config)) - .collect::<Result<Vec<_>>>()?; - let span = tracing::span!(tracing::Level::TRACE, "encoder"); - Ok(BertEncoder { layers, span }) - } - - fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let mut hidden_states = hidden_states.clone(); - // Use a loop rather than a fold as it's easier to modify when adding debug/... - for layer in self.layers.iter() { - hidden_states = layer.forward(&hidden_states)? 
- } - Ok(hidden_states) - } -} - -// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874 -pub struct BertModel { - embeddings: BertEmbeddings, - encoder: BertEncoder, - pub device: Device, - span: tracing::Span, -} - -impl BertModel { - pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> { - let (embeddings, encoder) = match ( - BertEmbeddings::load(vb.pp("embeddings"), config), - BertEncoder::load(vb.pp("encoder"), config), - ) { - (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), - (Err(err), _) | (_, Err(err)) => { - if let Some(model_type) = &config.model_type { - if let (Ok(embeddings), Ok(encoder)) = ( - BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), - BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), - ) { - (embeddings, encoder) - } else { - return Err(err); - } - } else { - return Err(err); - } - } - }; - Ok(Self { - embeddings, - encoder, - device: vb.device().clone(), - span: tracing::span!(tracing::Level::TRACE, "model"), - }) - } - - pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?; - let sequence_output = self.encoder.forward(&embedding_output)?; - Ok(sequence_output) - } -} diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs index 652cd47f..3540f75d 100644 --- a/candle-examples/examples/bigcode/main.rs +++ b/candle-examples/examples/bigcode/main.rs @@ -7,8 +7,7 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -mod model; -use model::{Config, GPTBigCode}; +use candle_transformers::models::bigcode::{Config, GPTBigCode}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs deleted file mode 100644 index 1e63956b..00000000 --- a/candle-examples/examples/bigcode/model.rs +++ /dev/null @@ -1,359 +0,0 @@ -use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; - -fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), "weight")?; - let bias = if bias { - Some(vb.get(size2, "bias")?) 
- } else { - None - }; - Ok(Linear::new(weight, bias)) -} - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - -fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { - let weight = vb.get(size, "weight")?; - let bias = vb.get(size, "bias")?; - Ok(LayerNorm::new(weight, bias, eps)) -} - -fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> { - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j <= i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), device)?; - Ok(mask) -} - -#[derive(Debug)] -pub struct Config { - pub vocab_size: usize, - // max_position_embeddings aka n_positions - pub max_position_embeddings: usize, - // num_hidden_layers aka n_layer - pub num_hidden_layers: usize, - // hidden_size aka n_embd - pub hidden_size: usize, - pub layer_norm_epsilon: f64, - pub n_inner: Option<usize>, - // num_attention_heads aka n_head - pub num_attention_heads: usize, - pub multi_query: bool, - pub use_cache: bool, -} - -impl Config { - #[allow(dead_code)] - pub fn starcoder_1b() -> Self { - Self { - vocab_size: 49152, - max_position_embeddings: 8192, - num_hidden_layers: 24, - hidden_size: 2048, - layer_norm_epsilon: 1e-5, - n_inner: Some(8192), - num_attention_heads: 16, - multi_query: true, - use_cache: true, - } - } - - #[allow(dead_code)] - pub fn starcoder_3b() -> Self { - Self { - vocab_size: 49152, - max_position_embeddings: 8192, - num_hidden_layers: 36, - hidden_size: 2816, - layer_norm_epsilon: 1e-5, - n_inner: Some(11264), - num_attention_heads: 22, - multi_query: true, - use_cache: true, - } - } - - #[allow(dead_code)] - pub fn starcoder_7b() -> Self { - Self { - vocab_size: 49152, - max_position_embeddings: 8192, - num_hidden_layers: 42, - hidden_size: 4096, - layer_norm_epsilon: 1e-5, - n_inner: Some(16384), - num_attention_heads: 32, - multi_query: true, - use_cache: true, - } - } - - #[allow(dead_code)] - pub fn starcoder() -> Self { - Self { - vocab_size: 49152, - max_position_embeddings: 8192, - num_hidden_layers: 40, - hidden_size: 6144, - layer_norm_epsilon: 1e-5, - n_inner: Some(24576), - num_attention_heads: 48, - multi_query: true, - use_cache: true, - } - } -} - -struct Attention { - c_attn: Linear, - c_proj: Linear, - kv_cache: Option<Tensor>, - use_cache: bool, - embed_dim: usize, - kv_dim: usize, - num_heads: usize, - head_dim: usize, - multi_query: bool, -} - -impl Attention { - pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let hidden_size = cfg.hidden_size; - let head_dim = hidden_size / cfg.num_attention_heads; - let kv_heads = if cfg.multi_query { - 1 - } else { - cfg.num_attention_heads - }; - let kv_dim = kv_heads * head_dim; - let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?; - let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?; - Ok(Self { - c_proj, - c_attn, - embed_dim: hidden_size, - kv_cache: None, - use_cache: cfg.use_cache, - kv_dim, - head_dim, - num_heads: cfg.num_attention_heads, - multi_query: cfg.multi_query, - }) - } - - fn attn( - &self, - query: &Tensor, - key: &Tensor, - value: &Tensor, - attention_mask: &Tensor, - ) -> Result<Tensor> { - if query.dtype() != DType::F32 { - // If we start supporting f16 models, we may need the upcasting scaling bits. 
- // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133 - candle::bail!("upcasting is not supported {:?}", query.dtype()) - } - let scale_factor = 1f64 / (self.head_dim as f64).sqrt(); - let initial_query_shape = query.shape(); - let key_len = key.dim(D::Minus1)?; - let (query, key, attn_shape, attn_view) = if self.multi_query { - let (b_sz, query_len, _) = query.dims3()?; - let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; - let attn_shape = (b_sz, query_len, self.num_heads, key_len); - let attn_view = (b_sz, query_len * self.num_heads, key_len); - (query, key.clone(), attn_shape, attn_view) - } else { - let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?; - let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; - let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?; - let attn_shape = (b_sz, self.num_heads, query_len, key_len); - let attn_view = (b_sz * self.num_heads, query_len, key_len); - (query, key, attn_shape, attn_view) - }; - - let attn_weights = - (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?; - let attention_mask = attention_mask.broadcast_as(attn_shape)?; - let mask_value = - Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; - let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; - let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; - let value = value.contiguous()?; - let attn_output = if self.multi_query { - attn_weights - .reshape(attn_view)? - .matmul(&value)? - .reshape(initial_query_shape)? - } else { - attn_weights.matmul(&value)? - }; - Ok(attn_output) - } - - fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { - let qkv = self.c_attn.forward(hidden_states)?; - let (query, key_value) = if self.multi_query { - let query = qkv.i((.., .., ..self.embed_dim))?; - let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?; - (query, key_value) - } else { - let mut dims = qkv.dims().to_vec(); - dims.pop(); - dims.push(self.embed_dim); - dims.push(self.head_dim * 3); - let qkv = qkv.reshape(dims)?.transpose(1, 2)?; - let query = qkv.i((.., .., .., ..self.head_dim))?; - let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?; - (query, key_value) - }; - let mut key_value = key_value; - if self.use_cache { - if let Some(kv_cache) = &self.kv_cache { - // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for - // arbitrarily large sizes. - key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?; - } - self.kv_cache = Some(key_value.clone()) - } - - let key = key_value.narrow(D::Minus1, 0, self.head_dim)?; - let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?; - let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?; - let attn_output = if self.multi_query { - attn_output - } else { - attn_output - .transpose(1, 2)? - .reshape(hidden_states.shape())? 
- }; - let attn_output = self.c_proj.forward(&attn_output)?; - Ok(attn_output) - } -} - -struct Mlp { - c_fc: Linear, - c_proj: Linear, -} - -impl Mlp { - fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> { - let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?; - let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?; - Ok(Self { c_fc, c_proj }) - } - - fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> { - let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?; - let hidden_states = self.c_proj.forward(&hidden_states)?; - Ok(hidden_states) - } -} - -// TODO: Add cross-attention? -struct Block { - ln_1: LayerNorm, - attn: Attention, - ln_2: LayerNorm, - mlp: Mlp, -} - -impl Block { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let hidden_size = cfg.hidden_size; - let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size); - let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?; - let attn = Attention::load(vb.pp("attn"), cfg)?; - let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?; - let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?; - Ok(Self { - ln_1, - attn, - ln_2, - mlp, - }) - } - - fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { - let residual = hidden_states; - let hidden_states = self.ln_1.forward(hidden_states)?; - let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?; - let hidden_states = (&attn_outputs + residual)?; - let residual = &hidden_states; - let hidden_states = self.ln_2.forward(&hidden_states)?; - let hidden_states = self.mlp.forward(&hidden_states)?; - let hidden_states = (&hidden_states + residual)?; - Ok(hidden_states) - } -} - -pub struct GPTBigCode { - wte: Embedding, - wpe: Embedding, - blocks: Vec<Block>, - ln_f: LayerNorm, - lm_head: Linear, - bias: Tensor, - config: Config, -} - -impl GPTBigCode { - pub fn config(&self) -> &Config { - &self.config - } - - pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { - let hidden_size = cfg.hidden_size; - let vb_t = vb.pp("transformer"); - let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?; - let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?; - let blocks = (0..cfg.num_hidden_layers) - .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg)) - .collect::<Result<Vec<_>>>()?; - let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?; - let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?; - let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?; - Ok(Self { - wte, - wpe, - blocks, - lm_head, - ln_f, - bias, - config: cfg, - }) - } - - pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> { - let dev = input_ids.device(); - let (b_sz, seq_len) = input_ids.dims2()?; - - let key_len = past_len + seq_len; - let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?; - // MQA models: (batch_size, query_length, n_heads, key_length) - // MHA models: (batch_size, n_heads, query_length, key_length) - let seq_len_dim = if self.config.multi_query { 2 } else { 1 }; - let attention_mask = attention_mask.unsqueeze(seq_len_dim)?; - - let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?; - let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?; - let input_embeds = self.wte.forward(input_ids)?; - let position_embeds = 
self.wpe.forward(&position_ids)?; - - let mut hidden_states = (&input_embeds + &position_embeds)?; - for block in self.blocks.iter_mut() { - hidden_states = block.forward(&hidden_states, &attention_mask)?; - } - let hidden_states = self.ln_f.forward(&hidden_states)?; - let hidden_states = hidden_states - .reshape((b_sz, seq_len, self.config.hidden_size))? - .narrow(1, seq_len - 1, 1)?; - let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?; - Ok(logits) - } -} diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index 05507f08..c45fe545 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -14,8 +14,7 @@ use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; -mod model; -use model::{Config, Falcon}; +use candle_transformers::models::falcon::{Config, Falcon}; struct TextGeneration { model: Falcon, diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs deleted file mode 100644 index b638dd51..00000000 --- a/candle-examples/examples/falcon/model.rs +++ /dev/null @@ -1,485 +0,0 @@ -use anyhow::Result; -use candle::{DType, Device, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; - -const MAX_SEQ_LEN: usize = 5000; - -fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), "weight")?; - let bias = if bias { - Some(vb.get(size2, "bias")?) - } else { - None - }; - Ok(Linear::new(weight, bias)) -} - -fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { - let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { - (Ok(weight), Ok(bias)) => (weight, bias), - (Err(err), _) | (_, Err(err)) => { - if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { - (weight, bias) - } else { - return Err(err.into()); - } - } - }; - Ok(LayerNorm::new(weight, bias, eps)) -} - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - -// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py -#[derive(Debug)] -pub struct Config { - pub vocab_size: usize, - pub hidden_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub layer_norm_epsilon: f64, - pub initializer_range: f64, - pub use_cache: bool, - pub bos_token_id: u32, - pub eos_token_id: u32, - pub hidden_dropout: f64, - pub attention_dropout: f64, - pub n_head_kv: Option<usize>, - pub alibi: bool, - pub new_decoder_architecture: bool, - pub multi_query: bool, - pub parallel_attn: bool, - pub bias: bool, -} - -impl Default for Config { - fn default() -> Self { - Self { - vocab_size: 65024, - hidden_size: 4544, - num_hidden_layers: 32, - num_attention_heads: 71, - layer_norm_epsilon: 1e-5, - initializer_range: 0.02, - use_cache: true, - bos_token_id: 11, - eos_token_id: 11, - hidden_dropout: 0.0, - attention_dropout: 0.0, - n_head_kv: None, - alibi: false, - new_decoder_architecture: false, - multi_query: true, - parallel_attn: true, - bias: false, - } - } -} - -impl Config { - pub fn validate(&self) -> Result<()> { - if self.alibi { - anyhow::bail!("alibi is not supported"); - } - if self.new_decoder_architecture { - 
anyhow::bail!("new_decoder_architecture is not supported"); - } - if self.n_head_kv.is_some() { - anyhow::bail!("n_head_kv is not supported"); - } - Ok(()) - } - - // https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json - pub fn falcon7b() -> Self { - // This is currently on par with the defaults, the defaults come from the Python default - // arguments for the config initialization whereas the following come from the json config. - Self { - vocab_size: 65024, - hidden_size: 4544, - num_hidden_layers: 32, - num_attention_heads: 71, - layer_norm_epsilon: 1e-5, - initializer_range: 0.02, - use_cache: true, - bos_token_id: 11, - eos_token_id: 11, - hidden_dropout: 0., - attention_dropout: 0., - n_head_kv: None, - alibi: false, - new_decoder_architecture: false, - multi_query: true, - parallel_attn: true, - bias: false, - } - } - - fn head_dim(&self) -> usize { - self.hidden_size / self.num_attention_heads - } - - fn rotary(&self) -> bool { - !self.alibi - } -} - -fn rotate_half(x: &Tensor) -> Result<Tensor> { - let l = x.dim(D::Minus1)?; - let x1 = x.narrow(D::Minus1, 0, l / 2)?; - let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?; - let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; - Ok(x21) -} - -#[derive(Debug)] -struct FalconRotaryEmbedding { - inv_freq: Tensor, - cache: Option<(usize, Tensor, Tensor)>, -} - -impl FalconRotaryEmbedding { - fn load(device: &Device, cfg: &Config) -> Result<Self> { - let head_dim = cfg.head_dim(); - let inv_freq: Vec<_> = (0..head_dim) - .step_by(2) - .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32)) - .collect(); - Ok(Self { - inv_freq: Tensor::new(inv_freq.as_slice(), device)?, - cache: None, - }) - } - - fn cos_sin( - &mut self, - seq_len: usize, - device: &Device, - dtype: DType, - ) -> Result<(Tensor, Tensor)> { - match &self.cache { - Some((s, cos, sin)) if *s == seq_len => { - return Ok((cos.clone(), sin.clone())); - } - _ => {} - } - let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?; - let inv_freq = self.inv_freq.to_dtype(dtype)?; - let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?; - let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; - let cos = emb.cos()?; - let sin = emb.sin()?; - self.cache = Some((seq_len, cos.clone(), sin.clone())); - Ok((cos, sin)) - } - - fn forward( - &mut self, - query: &Tensor, - key: &Tensor, - past_kv_len: usize, - ) -> Result<(Tensor, Tensor)> { - let (_batch, seq_len, _head_dim) = query.dims3()?; - let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?; - let cos = cos.narrow(0, past_kv_len, seq_len)?; - let sin = sin.narrow(0, past_kv_len, seq_len)?; - let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?; - let ks = (key.broadcast_mul(&cos)? 
+ &rotate_half(key)?.broadcast_mul(&sin)?)?; - Ok((qs, ks)) - } -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { - let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; - Ok(m) -} - -#[derive(Debug)] -struct FalconAttention { - query_key_value: Linear, - dense: Linear, - maybe_rotary: Option<FalconRotaryEmbedding>, - kv_cache: Option<(Tensor, Tensor)>, - inv_norm_factor: f64, - multi_query: bool, - use_cache: bool, - num_heads: usize, - head_dim: usize, - n_head_kv: usize, -} - -impl FalconAttention { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let maybe_rotary = if cfg.rotary() { - let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?; - Some(rotary) - } else { - None - }; - let head_dim = cfg.head_dim(); - let hidden_size = cfg.hidden_size; - let qkv_out_dim = if cfg.multi_query { - hidden_size + 2 * head_dim - } else { - 3 * hidden_size - }; - let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?; - let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?; - Ok(Self { - query_key_value, - dense, - maybe_rotary, - kv_cache: None, - inv_norm_factor: 1. / (head_dim as f64).sqrt(), - multi_query: cfg.multi_query, - use_cache: cfg.use_cache, - num_heads: cfg.num_attention_heads, - n_head_kv: cfg.n_head_kv.unwrap_or(1), - head_dim, - }) - } - - fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { - let (b_sz, seq_len, _) = fused_qkv.dims3()?; - if !self.multi_query { - let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?; - let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?; - let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?; - let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?; - Ok((q, k, v)) - } else { - let fused_qkv = - fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?; - let d = fused_qkv.dim(D::Minus2)?; - let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?; - let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?; - let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?; - Ok((q, k, v)) - } - } - - fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> { - let fused_qkv = self.query_key_value.forward(x)?; - let head_dim = self.head_dim; - let (query, key, value) = self.split_heads(&fused_qkv)?; - let (b_sz, seq_len, _, _) = query.dims4()?; - let query = query - .transpose(1, 2)? - .reshape((b_sz * self.num_heads, seq_len, head_dim))?; - let key = key - .transpose(1, 2)? - .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?; - let value = value - .transpose(1, 2)? - .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?; - let (query, key) = if let Some(r) = &mut self.maybe_rotary { - r.forward(&query, &key, past_kv_len)? - } else { - (query, key) - }; - let (mut key, mut value) = (key, value); - let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?; - if self.use_cache { - if let Some((cache_k, cache_v)) = &self.kv_cache { - // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for - // arbitrarily large sizes. 
- key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?; - value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?; - } - self.kv_cache = Some((key.clone(), value.clone())) - } - let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?; - let all_len = past_kv_len + seq_len; - let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?; - let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?; - - let (key, value) = if self.n_head_kv == 1 { - ( - key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?, - value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?, - ) - } else { - (key, value) - }; - - // Only handle the case where alibi is None here, and non-flash attention. - let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?; - let attention_scores = candle_nn::ops::softmax( - &attention_scores - .broadcast_add(&mask.squeeze(1)?)? - .to_dtype(DType::F32)?, - D::Minus1, - )? - .to_dtype(x.dtype())?; - let attn_output = attention_scores - .matmul(&value)? - .reshape((b_sz, self.num_heads, seq_len, head_dim))? - .transpose(1, 2)? - .reshape((b_sz, seq_len, self.num_heads * head_dim))?; - let attn_output = self.dense.forward(&attn_output)?; - Ok(attn_output) - } -} - -#[derive(Debug)] -struct FalconMlp { - dense_h_to_4h: Linear, - dense_4h_to_h: Linear, -} - -impl FalconMlp { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let h = cfg.hidden_size; - let b = cfg.bias; - let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?; - let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?; - Ok(Self { - dense_h_to_4h, - dense_4h_to_h, - }) - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = self.dense_h_to_4h.forward(x)?.gelu()?; - let x = self.dense_4h_to_h.forward(&x)?; - Ok(x) - } -} - -#[derive(Debug)] -struct FalconDecoderLayer { - inp_layernorm: LayerNorm, - self_attention: FalconAttention, - post_attention_layernorm: Option<LayerNorm>, - mlp: FalconMlp, - parallel_attn: bool, -} - -impl FalconDecoderLayer { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?; - let inp_layernorm = layer_norm( - cfg.hidden_size, - cfg.layer_norm_epsilon, - vb.pp("input_layernorm"), - )?; - let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?; - let post_attention_layernorm = if cfg.parallel_attn { - None - } else { - let ln = layer_norm( - cfg.hidden_size, - cfg.layer_norm_epsilon, - vb.pp("post_attention_layernorm"), - )?; - Some(ln) - }; - Ok(Self { - inp_layernorm, - self_attention, - post_attention_layernorm, - mlp, - parallel_attn: cfg.parallel_attn, - }) - } - - fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> { - let residual = x.clone(); - let ln_attn = self.inp_layernorm.forward(x)?; - let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?; - let (residual, ln_mlp) = match &self.post_attention_layernorm { - None => (residual, ln_attn), - Some(pal) => { - // This should include some dropout. - let residual = (&attn_output + &residual)?; - let ln_mlp = pal.forward(&residual)?; - (residual, ln_mlp) - } - }; - let mlp_output = self.mlp.forward(&ln_mlp)?; - - let mlp_output = if self.parallel_attn { - (mlp_output + attn_output)? 
- } else { - mlp_output - }; - let output = (mlp_output + residual)?; - Ok(output) - } -} - -#[derive(Debug)] -pub struct Falcon { - word_embeddings: Embedding, - blocks: Vec<FalconDecoderLayer>, - ln_f: LayerNorm, - lm_head: Linear, - config: Config, -} - -fn make_causal_mask(t: usize) -> Result<Tensor> { - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; - Ok(mask) -} - -fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> { - // let mask = Tensor::ones((b_sz, seq_len), DType::U32, &Device::Cpu)?; - let mask = make_causal_mask(seq_len)?; - let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?; - Ok(mask) -} - -impl Falcon { - pub fn config(&self) -> &Config { - &self.config - } - - pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { - let word_embeddings = embedding( - cfg.vocab_size, - cfg.hidden_size, - vb.pp("transformer.word_embeddings"), - )?; - let blocks = (0..cfg.num_hidden_layers) - .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg)) - .collect::<Result<Vec<_>>>()?; - let ln_f = layer_norm( - cfg.hidden_size, - cfg.layer_norm_epsilon, - vb.pp("transformer.ln_f"), - )?; - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?; - Ok(Self { - word_embeddings, - blocks, - ln_f, - lm_head, - config: cfg, - }) - } - - pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { - let (b_sz, seq_len) = input_ids.dims2()?; - let mut hidden_state = self.word_embeddings.forward(input_ids)?; - let past_kv_len = match &self.blocks[0].self_attention.kv_cache { - Some((k, _)) => k.dim(1)?, - None => 0, - }; - let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?; - for block in self.blocks.iter_mut() { - hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?; - } - let hidden_state = self.ln_f.forward(&hidden_state)?; - let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?; - let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?; - Ok(logits) - } -} diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 6f8766d4..db3d216c 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use std::io::Write; -mod model; +use candle_transformers::models::llama as model; use model::{Config, Llama, LlamaConfig}; const EOS_TOKEN: &str = "</s>"; -const MAX_SEQ_LEN: usize = 4096; const DEFAULT_PROMPT: &str = "My favorite theorem is "; #[derive(Parser, Debug)] diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs deleted file mode 100644 index 275856e0..00000000 --- a/candle-examples/examples/llama/model.rs +++ /dev/null @@ -1,446 +0,0 @@ -use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, Module, VarBuilder}; -use serde::Deserialize; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; - -use super::MAX_SEQ_LEN; - -#[derive(Deserialize)] -pub struct LlamaConfig { - pub hidden_size: usize, - pub intermediate_size: usize, - pub vocab_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub num_key_value_heads: Option<usize>, - pub rms_norm_eps: f64, - #[serde(default = "default_rope")] - pub rope_theta: f32, -} - -fn default_rope() -> f32 { - 10_000.0 -} 
- -impl LlamaConfig { - pub fn into_config(self, use_flash_attn: bool) -> Config { - Config { - hidden_size: self.hidden_size, - intermediate_size: self.intermediate_size, - vocab_size: self.vocab_size, - num_hidden_layers: self.num_hidden_layers, - num_attention_heads: self.num_attention_heads, - num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads), - rms_norm_eps: self.rms_norm_eps, - rope_theta: self.rope_theta, - use_flash_attn, - } - } -} - -pub struct Config { - pub hidden_size: usize, - pub intermediate_size: usize, - pub vocab_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub num_key_value_heads: usize, - pub use_flash_attn: bool, - pub rms_norm_eps: f64, - pub rope_theta: f32, -} - -impl Config { - pub fn config_7b_v1(use_flash_attn: bool) -> Self { - Self { - hidden_size: 4096, - intermediate_size: 11008, - vocab_size: 32000, - num_hidden_layers: 32, - num_attention_heads: 32, - num_key_value_heads: 32, - use_flash_attn, - rms_norm_eps: 1e-6, - rope_theta: 10_000.0, - } - } - - pub fn config_7b_v2(use_flash_attn: bool) -> Self { - Self { - hidden_size: 4096, - intermediate_size: 11008, - vocab_size: 32000, - num_hidden_layers: 32, - num_attention_heads: 32, - num_key_value_heads: 32, - use_flash_attn, - rms_norm_eps: 1e-5, - rope_theta: 10_000.0, - } - } -} - -// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting -// model. -#[derive(Debug)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Linear { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -#[derive(Clone)] -pub struct Cache { - masks: Arc<Mutex<HashMap<usize, Tensor>>>, - pub use_kv_cache: bool, - #[allow(clippy::type_complexity)] - kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, - cos: Tensor, - sin: Tensor, - device: Device, -} - -impl Cache { - pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> { - // precompute freqs_cis - let n_elem = config.hidden_size / config.num_attention_heads; - let theta: Vec<_> = (0..n_elem) - .step_by(2) - .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), device)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? - .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? 
- .matmul(&theta.reshape((1, theta.elem_count()))?)?; - // This is different from the paper, see: - // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 - let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?; - let cos = idx_theta.cos()?.to_dtype(dtype)?; - let sin = idx_theta.sin()?.to_dtype(dtype)?; - Ok(Self { - masks: Arc::new(Mutex::new(HashMap::new())), - use_kv_cache, - kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])), - device: device.clone(), - cos, - sin, - }) - } - - fn mask(&self, t: usize) -> Result<Tensor> { - let mut masks = self.masks.lock().unwrap(); - if let Some(mask) = masks.get(&t) { - Ok(mask.clone()) - } else { - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; - masks.insert(t, mask.clone()); - Ok(mask) - } - } -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear_no_bias(size1, size2, vb)?; - Ok(Linear { inner, span }) -} - -fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; - Ok(Embedding::new(embeddings, cfg.hidden_size)) -} - -struct RmsNorm { - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let inner = candle_nn::rms_norm(size, eps, vb)?; - Ok(Self { inner, span }) - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -struct CausalSelfAttention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - o_proj: Linear, - num_attention_heads: usize, - num_key_value_heads: usize, - head_dim: usize, - cache: Cache, - use_flash_attn: bool, - span: tracing::Span, - span_rot: tracing::Span, -} - -#[cfg(feature = "flash-attn")] -fn flash_attn( - q: &Tensor, - k: &Tensor, - v: &Tensor, - softmax_scale: f32, - causal: bool, -) -> Result<Tensor> { - candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) -} - -#[cfg(not(feature = "flash-attn"))] -fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { - unimplemented!("compile with '--features flash-attn'") -} - -impl CausalSelfAttention { - fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let _enter = self.span_rot.enter(); - let (b_sz, _, seq_len, hidden_size) = x.dims4()?; - let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; - let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; - let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?; - let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?; - let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?; - let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?; - let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; - let rope = (x.broadcast_mul(&cos)? 
+ rotate_x.broadcast_mul(&sin)?)?; - Ok(rope) - } - - fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { - let _enter = self.span.enter(); - let (b_sz, seq_len, hidden_size) = x.dims3()?; - let q = self.q_proj.forward(x)?; - let k = self.k_proj.forward(x)?; - let v = self.v_proj.forward(x)?; - - let q = q - .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? - .transpose(1, 2)?; - let mut v = v - .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? - .transpose(1, 2)?; - - let q = self.apply_rotary_emb(&q, index_pos)?; - let mut k = self.apply_rotary_emb(&k, index_pos)?; - - if self.cache.use_kv_cache { - let mut cache = self.cache.kvs.lock().unwrap(); - if let Some((cache_k, cache_v)) = &cache[block_idx] { - k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; - v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; - let k_seq_len = k.dims()[1]; - if k_seq_len > MAX_SEQ_LEN { - k = k - .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? - .contiguous()? - } - let v_seq_len = v.dims()[1]; - if v_seq_len > 2 * MAX_SEQ_LEN { - v = v - .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? - .contiguous()? - } - } - cache[block_idx] = Some((k.clone(), v.clone())) - } - - let k = self.repeat_kv(k)?; - let v = self.repeat_kv(v)?; - - let y = if self.use_flash_attn { - // flash-attn expects (b_sz, seq_len, nheads, head_dim) - let q = q.transpose(1, 2)?; - let k = k.transpose(1, 2)?; - let v = v.transpose(1, 2)?; - let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); - flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)? - } else { - let in_dtype = q.dtype(); - let q = q.to_dtype(DType::F32)?; - let k = k.to_dtype(DType::F32)?; - let v = v.to_dtype(DType::F32)?; - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = candle_nn::ops::softmax(&att, D::Minus1)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)? - }; - let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?; - let y = self.o_proj.forward(&y)?; - Ok(y) - } - - fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { - let n_rep = self.num_attention_heads / self.num_key_value_heads; - if n_rep == 1 { - Ok(x) - } else { - let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; - let x = x - .unsqueeze(2)? - .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? 
- .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; - Ok(x) - } - } - - fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "attn"); - let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let size_in = cfg.hidden_size; - let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; - let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; - let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; - let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; - let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; - let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; - Ok(Self { - q_proj, - k_proj, - v_proj, - o_proj, - num_attention_heads: cfg.num_attention_heads, - num_key_value_heads: cfg.num_key_value_heads, - head_dim: cfg.hidden_size / cfg.num_attention_heads, - cache: cache.clone(), - use_flash_attn: cfg.use_flash_attn, - span, - span_rot, - }) - } -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { - let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; - Ok(m) -} - -struct Mlp { - c_fc1: Linear, - c_fc2: Linear, - c_proj: Linear, - span: tracing::Span, -} - -impl Mlp { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; - self.c_proj.forward(&x) - } - - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "mlp"); - let h_size = cfg.hidden_size; - let i_size = cfg.intermediate_size; - let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; - let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; - let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; - Ok(Self { - c_fc1, - c_fc2, - c_proj, - span, - }) - } -} - -struct Block { - rms_1: RmsNorm, - attn: CausalSelfAttention, - rms_2: RmsNorm, - mlp: Mlp, - span: tracing::Span, -} - -impl Block { - fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { - let _enter = self.span.enter(); - let residual = x; - let x = self.rms_1.forward(x)?; - let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; - let residual = &x; - let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? 
+ residual)?; - Ok(x) - } - - fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "block"); - let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; - let mlp = Mlp::load(vb.pp("mlp"), cfg)?; - let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; - let rms_2 = RmsNorm::load( - cfg.hidden_size, - cfg.rms_norm_eps, - vb.pp("post_attention_layernorm"), - )?; - Ok(Self { - rms_1, - attn, - rms_2, - mlp, - span, - }) - } -} - -pub struct Llama { - wte: Embedding, - blocks: Vec<Block>, - ln_f: RmsNorm, - lm_head: Linear, -} - -impl Llama { - pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let (_b_sz, seq_len) = x.dims2()?; - let mut x = self.wte.forward(x)?; - for (block_idx, block) in self.blocks.iter().enumerate() { - x = block.forward(&x, index_pos, block_idx)?; - } - let x = self.ln_f.forward(&x)?; - let x = x.i((.., seq_len - 1, ..))?; - let logits = self.lm_head.forward(&x)?; - logits.to_dtype(DType::F32) - } - - pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { - let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; - let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; - let blocks: Vec<_> = (0..cfg.num_hidden_layers) - .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap()) - .collect(); - - Ok(Self { - wte, - blocks, - ln_f, - lm_head, - }) - } -} diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs deleted file mode 100644 index 2ceed065..00000000 --- a/candle-examples/examples/whisper/audio.rs +++ /dev/null @@ -1,214 +0,0 @@ -// Audio processing code, adapted from whisper.cpp -// https://github.com/ggerganov/whisper.cpp - -pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} - -impl Float for f32 {} -impl Float for f64 {} - -// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 -fn fft<T: Float>(inp: &[T]) -> Vec<T> { - let n = inp.len(); - let zero = T::zero(); - if n == 1 { - return vec![inp[0], zero]; - } - if n % 2 == 1 { - return dft(inp); - } - let mut out = vec![zero; n * 2]; - - let mut even = Vec::with_capacity(n / 2); - let mut odd = Vec::with_capacity(n / 2); - - for (i, &inp) in inp.iter().enumerate() { - if i % 2 == 0 { - even.push(inp) - } else { - odd.push(inp); - } - } - - let even_fft = fft(&even); - let odd_fft = fft(&odd); - - let two_pi = T::PI() + T::PI(); - let n_t = T::from(n).unwrap(); - for k in 0..n / 2 { - let k_t = T::from(k).unwrap(); - let theta = two_pi * k_t / n_t; - let re = theta.cos(); - let im = -theta.sin(); - - let re_odd = odd_fft[2 * k]; - let im_odd = odd_fft[2 * k + 1]; - - out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd; - out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; - - out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd; - out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; - } - out -} - -// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337 -fn dft<T: Float>(inp: &[T]) -> Vec<T> { - let zero = T::zero(); - let n = inp.len(); - let two_pi = T::PI() + T::PI(); - - let mut out = Vec::new(); - out.reserve(2 * n); - let n_t = T::from(n).unwrap(); - for k in 0..n { - let k_t = 
T::from(k).unwrap(); - let mut re = zero; - let mut im = zero; - - for (j, &inp) in inp.iter().enumerate() { - let j_t = T::from(j).unwrap(); - let angle = two_pi * k_t * j_t / n_t; - re += inp * angle.cos(); - im -= inp * angle.sin(); - } - - out.push(re); - out.push(im); - } - out -} - -#[allow(clippy::too_many_arguments)] -// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414 -fn log_mel_spectrogram_w<T: Float>( - ith: usize, - hann: &[T], - samples: &[T], - filters: &[T], - fft_size: usize, - fft_step: usize, - speed_up: bool, - n_len: usize, - n_mel: usize, - n_threads: usize, -) -> Vec<T> { - let n_fft = if speed_up { - 1 + fft_size / 4 - } else { - 1 + fft_size / 2 - }; - - let zero = T::zero(); - let half = T::from(0.5).unwrap(); - let mut fft_in = vec![zero; fft_size]; - let mut mel = vec![zero; n_len * n_mel]; - - for i in (ith..n_len).step_by(n_threads) { - let offset = i * fft_step; - - // apply Hanning window - for j in 0..fft_size { - fft_in[j] = if offset + j < samples.len() { - hann[j] * samples[offset + j] - } else { - zero - } - } - - // FFT -> mag^2 - let mut fft_out: Vec<T> = fft(&fft_in); - - for j in 0..fft_size { - fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1]; - } - for j in 1..fft_size / 2 { - let v = fft_out[fft_size - j]; - fft_out[j] += v; - } - - if speed_up { - // scale down in the frequency domain results in a speed up in the time domain - for j in 0..n_fft { - fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]); - } - } - - // mel spectrogram - for j in 0..n_mel { - let mut sum = zero; - for k in 0..n_fft { - sum += fft_out[k] * filters[j * n_fft + k]; - } - mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); - } - } - mel -} - -fn log_mel_spectrogram_<T: Float + std::fmt::Display>( - samples: &[T], - filters: &[T], - fft_size: usize, - fft_step: usize, - n_mel: usize, - speed_up: bool, -) -> Vec<T> { - let zero = T::zero(); - let two_pi = T::PI() + T::PI(); - let half = T::from(0.5).unwrap(); - let one = T::from(1.0).unwrap(); - let four = T::from(4.0).unwrap(); - let fft_size_t = T::from(fft_size).unwrap(); - - let hann: Vec<T> = (0..fft_size) - .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) - .collect(); - let n_len = samples.len() / fft_step; - - // pad audio with at least one extra chunk of zeros - let pad = 100 * super::CHUNK_LENGTH / 2; - let n_len = if n_len % pad != 0 { - (n_len / pad + 1) * pad - } else { - n_len - }; - let n_len = n_len + pad; - let samples = { - let mut samples_padded = samples.to_vec(); - let to_add = n_len * fft_step - samples.len(); - samples_padded.extend(std::iter::repeat(zero).take(to_add)); - samples_padded - }; - - // Use a single thread for now. 
- let mut mel = log_mel_spectrogram_w( - 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1, - ); - let mmax = mel - .iter() - .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) - .copied() - .unwrap_or(zero) - - T::from(8).unwrap(); - for m in mel.iter_mut() { - let v = T::max(*m, mmax); - *m = v / four + one - } - mel -} - -pub fn pcm_to_mel<T: Float + std::fmt::Display>( - samples: &[T], - filters: &[T], -) -> anyhow::Result<Vec<T>> { - let mel = log_mel_spectrogram_( - samples, - filters, - super::N_FFT, - super::HOP_LENGTH, - super::N_MELS, - false, - ); - Ok(mel) -} diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index dbe9cc8d..c71d562a 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -10,41 +10,16 @@ extern crate accelerate_src; extern crate intel_mkl_src; use anyhow::{Error as E, Result}; -use candle::{DType, Device, IndexOp, Tensor}; +use candle::{Device, IndexOp, Tensor}; use candle_nn::{ops::softmax, VarBuilder}; use clap::{Parser, ValueEnum}; use hf_hub::{api::sync::Api, Repo, RepoType}; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; -mod audio; -mod model; -use model::{Config, Whisper}; mod multilingual; - -const DTYPE: DType = DType::F32; - -// Audio parameters. -const SAMPLE_RATE: usize = 16000; -const N_FFT: usize = 400; -const N_MELS: usize = 80; -const HOP_LENGTH: usize = 160; -const CHUNK_LENGTH: usize = 30; -const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk -const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input - -const NO_SPEECH_THRESHOLD: f64 = 0.6; -const LOGPROB_THRESHOLD: f64 = -1.0; -const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; -const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; - -// Tokenizer dependent bits. -const SOT_TOKEN: &str = "<|startoftranscript|>"; -const TRANSCRIBE_TOKEN: &str = "<|transcribe|>"; -const TRANSLATE_TOKEN: &str = "<|translate|>"; -const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>"; -const EOT_TOKEN: &str = "<|endoftext|>"; -const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; +use candle_transformers::models::whisper::{self as m, audio, model}; +use model::{Config, Whisper}; #[allow(dead_code)] #[derive(Debug, Clone)] @@ -94,7 +69,7 @@ impl Decoder { timestamps: bool, verbose: bool, ) -> Result<Self> { - let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?; + let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?; // Suppress the notimestamps token when in timestamps mode. 
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452 let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32) @@ -109,11 +84,11 @@ impl Decoder { }) .collect(); let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; - let sot_token = token_id(&tokenizer, SOT_TOKEN)?; - let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?; - let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?; - let eot_token = token_id(&tokenizer, EOT_TOKEN)?; - let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?; + let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?; + let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?; + let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?; + let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?; + let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?; Ok(Self { model, rng: rand::rngs::StdRng::seed_from_u64(seed), @@ -220,17 +195,17 @@ impl Decoder { } fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> { - for (i, &t) in TEMPERATURES.iter().enumerate() { + for (i, &t) in m::TEMPERATURES.iter().enumerate() { let dr: Result<DecodingResult> = self.decode(segment, t); - if i == TEMPERATURES.len() - 1 { + if i == m::TEMPERATURES.len() - 1 { return dr; } // On errors, we try again with a different temperature. match dr { Ok(dr) => { - let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD - || dr.avg_logprob < LOGPROB_THRESHOLD; - if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD { + let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD + || dr.avg_logprob < m::LOGPROB_THRESHOLD; + if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD { return Ok(dr); } } @@ -248,13 +223,13 @@ impl Decoder { let mut segments = vec![]; while seek < content_frames { let start = std::time::Instant::now(); - let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; - let segment_size = usize::min(content_frames - seek, N_FRAMES); + let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64; + let segment_size = usize::min(content_frames - seek, m::N_FRAMES); let mel_segment = mel.narrow(2, seek, segment_size)?; - let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64; let dr = self.decode_with_fallback(&mel_segment)?; seek += segment_size; - if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { + if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD { println!("no speech detected, skipping {seek} {dr:?}"); continue; } @@ -492,8 +467,8 @@ fn main() -> Result<()> { let mut input = std::fs::File::open(input)?; let (header, data) = wav::read(&mut input)?; println!("loaded wav data: {header:?}"); - if header.sampling_rate != SAMPLE_RATE as u32 { - anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE) + if header.sampling_rate != m::SAMPLE_RATE as u32 { + anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE) } let data = data.as_sixteen().expect("expected 16 bit wav file"); let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] @@ -501,14 +476,14 @@ fn main() -> Result<()> { .map(|v| *v as f32 / 32768.) 
.collect(); println!("pcm data loaded {}", pcm_data.len()); - let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?; + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters); let mel_len = mel.len(); - let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?; + let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?; println!("loaded mel: {:?}", mel.dims()); let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); + let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device); let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let mut model = Whisper::load(&vb, config)?; diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs deleted file mode 100644 index e58ab2ca..00000000 --- a/candle-examples/examples/whisper/model.rs +++ /dev/null @@ -1,416 +0,0 @@ -use candle::{Device, IndexOp, Result, Tensor, D}; -use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; -use serde::Deserialize; - -// The names in comments correspond to the original implementation: -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub struct Config { - pub num_mel_bins: usize, // n_mels - pub max_source_positions: usize, // n_audio_ctx - pub d_model: usize, // n_audio_state - pub encoder_attention_heads: usize, // n_audio_head - pub encoder_layers: usize, // n_audio_layer - pub vocab_size: usize, // n_vocab - pub max_target_positions: usize, // n_text_ctx - // pub n_text_state: usize, - pub decoder_attention_heads: usize, // n_text_head - pub decoder_layers: usize, // n_text_layer - pub suppress_tokens: Vec<u32>, -} - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} -// -// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting -// model. 
-#[derive(Debug)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Linear { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear(size1, size2, vb)?; - Ok(Linear { inner, span }) -} - -fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear_no_bias(size1, size2, vb)?; - Ok(Linear { inner, span }) -} - -fn conv1d( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - config: Conv1dConfig, - vb: VarBuilder, -) -> Result<Conv1d> { - let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; - let bias = vb.get(out_channels, "bias")?; - Ok(Conv1d::new(weight, Some(bias), config)) -} - -fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { - let weight = vb.get(size, "weight")?; - let bias = vb.get(size, "bias")?; - Ok(LayerNorm::new(weight, bias, 1e-5)) -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 -struct MultiHeadAttention { - query: Linear, - key: Linear, - value: Linear, - out: Linear, - n_head: usize, - span: tracing::Span, - softmax_span: tracing::Span, - matmul_span: tracing::Span, - kv_cache: Option<(Tensor, Tensor)>, -} - -impl MultiHeadAttention { - fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); - let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax"); - let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul"); - let query = linear(n_state, n_state, vb.pp("q_proj"))?; - let value = linear(n_state, n_state, vb.pp("v_proj"))?; - let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; - let out = linear(n_state, n_state, vb.pp("out_proj"))?; - Ok(Self { - query, - key, - value, - out, - n_head, - span, - softmax_span, - matmul_span, - kv_cache: None, - }) - } - - fn forward( - &mut self, - x: &Tensor, - xa: Option<&Tensor>, - mask: Option<&Tensor>, - flush_cache: bool, - ) -> Result<Tensor> { - let _enter = self.span.enter(); - let q = self.query.forward(x)?; - let (k, v) = match xa { - None => { - let k = self.key.forward(x)?; - let v = self.value.forward(x)?; - (k, v) - } - Some(x) => { - if flush_cache { - self.kv_cache = None; - } - if let Some((k, v)) = &self.kv_cache { - (k.clone(), v.clone()) - } else { - let k = self.key.forward(x)?; - let v = self.value.forward(x)?; - self.kv_cache = Some((k.clone(), v.clone())); - (k, v) - } - } - }; - let wv = self.qkv_attention(&q, &k, &v, mask)?; - let out = self.out.forward(&wv)?; - Ok(out) - } - - fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { - let (n_batch, n_ctx, n_state) = x.dims3()?; - let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; - x.reshape(target_dims)?.transpose(1, 2) - } - - fn qkv_attention( - &self, - q: &Tensor, - k: &Tensor, - v: &Tensor, - mask: Option<&Tensor>, - ) -> Result<Tensor> { - let (_, n_ctx, n_state) = q.dims3()?; - let scale = ((n_state / self.n_head) as f64).powf(-0.25); - let q = (self.reshape_head(q)? * scale)?; - let k = (self.reshape_head(k)?.transpose(2, 3)? 
* scale)?; - let v = self.reshape_head(v)?.contiguous()?; - let mut qk = { - let _enter = self.matmul_span.enter(); - q.matmul(&k)? - }; - if let Some(mask) = mask { - let mask = mask.i((0..n_ctx, 0..n_ctx))?; - qk = qk.broadcast_add(&mask)? - } - let w = { - let _enter = self.softmax_span.enter(); - softmax(&qk, D::Minus1)? - }; - let wv = { - let _enter = self.matmul_span.enter(); - w.matmul(&v)? - } - .transpose(1, 2)? - .flatten_from(2)?; - Ok(wv) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 -struct ResidualAttentionBlock { - attn: MultiHeadAttention, - attn_ln: LayerNorm, - cross_attn: Option<(MultiHeadAttention, LayerNorm)>, - mlp_linear1: Linear, - mlp_linear2: Linear, - mlp_ln: LayerNorm, - span: tracing::Span, -} - -impl ResidualAttentionBlock { - fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); - let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; - let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; - let cross_attn = if ca { - let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; - let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; - Some((cross_attn, cross_attn_ln)) - } else { - None - }; - let n_mlp = n_state * 4; - let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; - let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; - let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; - Ok(Self { - attn, - attn_ln, - cross_attn, - mlp_linear1, - mlp_linear2, - mlp_ln, - span, - }) - } - - fn forward( - &mut self, - x: &Tensor, - xa: Option<&Tensor>, - mask: Option<&Tensor>, - flush_kv_cache: bool, - ) -> Result<Tensor> { - let _enter = self.span.enter(); - let attn = self - .attn - .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; - let mut x = (x + attn)?; - if let Some((attn, ln)) = &mut self.cross_attn { - x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; - } - let mlp = self.mlp_linear2.forward( - &self - .mlp_linear1 - .forward(&self.mlp_ln.forward(&x)?)? - .gelu()?, - )?; - x + mlp - } -} - -fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { - let max_timescale = 10000f32; - let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; - let inv_timescales: Vec<_> = (0..channels / 2) - .map(|i| (i as f32 * (-log_timescale_increment)).exp()) - .collect(); - let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; - let arange = Tensor::arange(0, length as u32, &Device::Cpu)? - .to_dtype(candle::DType::F32)? - .unsqueeze(1)?; - let sh = (length, channels / 2); - let scaled_time = (arange.broadcast_as(sh)? 
* inv_timescales.broadcast_as(sh)?)?; - let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; - Ok(sincos) -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 -pub struct AudioEncoder { - conv1: Conv1d, - conv2: Conv1d, - positional_embedding: Tensor, - blocks: Vec<ResidualAttentionBlock>, - ln_post: LayerNorm, - span: tracing::Span, - conv1_span: tracing::Span, - conv2_span: tracing::Span, -} - -impl AudioEncoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); - let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); - let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); - let n_state = cfg.d_model; - let n_head = cfg.encoder_attention_heads; - let n_ctx = cfg.max_source_positions; - let cfg1 = Conv1dConfig { - padding: 1, - stride: 1, - groups: 1, - dilation: 1, - }; - let cfg2 = Conv1dConfig { - padding: 1, - stride: 2, - groups: 1, - dilation: 1, - }; - let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; - let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; - let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; - let blocks = (0..cfg.encoder_layers) - .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) - }) - .collect::<Result<Vec<_>>>()?; - let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; - Ok(Self { - conv1, - conv2, - positional_embedding, - blocks, - ln_post, - conv1_span, - conv2_span, - span, - }) - } - - pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { - let _enter = self.span.enter(); - let x = { - let _enter = self.conv1_span.enter(); - self.conv1.forward(x)?.gelu()? - }; - let x = { - let _enter = self.conv2_span.enter(); - self.conv2.forward(&x)?.gelu()? - }; - let x = x.transpose(1, 2)?; - let (_bsize, seq_len, _hidden) = x.dims3()?; - let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; - let mut x = x.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter_mut() { - x = block.forward(&x, None, None, flush_kv_cache)? 
- } - let x = self.ln_post.forward(&x)?; - Ok(x) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 -pub struct TextDecoder { - token_embedding: Embedding, - positional_embedding: Tensor, - blocks: Vec<ResidualAttentionBlock>, - ln: LayerNorm, - mask: Tensor, - span: tracing::Span, - span_final: tracing::Span, -} - -impl TextDecoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); - let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final"); - let n_state = cfg.d_model; - let n_head = cfg.decoder_attention_heads; - let n_ctx = cfg.max_target_positions; - let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; - let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; - let blocks = (0..cfg.decoder_layers) - .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) - }) - .collect::<Result<Vec<_>>>()?; - let ln = layer_norm(n_state, vb.pp("layer_norm"))?; - let mask: Vec<_> = (0..n_ctx) - .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) - .collect(); - let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; - Ok(Self { - token_embedding, - positional_embedding, - blocks, - ln, - mask, - span, - span_final, - }) - } - - pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { - let _enter = self.span.enter(); - let last = x.dim(D::Minus1)?; - let token_embedding = self.token_embedding.forward(x)?; - let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; - let mut x = token_embedding.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter_mut() { - x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; - } - self.ln.forward(&x) - } - - pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> { - let b_size = x.dim(0)?; - let w = self.token_embedding.embeddings().broadcast_left(b_size)?; - let logits = { - let _enter = self.span_final.enter(); - x.matmul(&w.t()?)? - }; - Ok(logits) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 -pub struct Whisper { - pub encoder: AudioEncoder, - pub decoder: TextDecoder, - pub config: Config, -} - -impl Whisper { - pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { - let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; - let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; - Ok(Self { - encoder, - decoder, - config, - }) - } -} diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs index bc0bae1f..a82b09ef 100644 --- a/candle-examples/examples/whisper/multilingual.rs +++ b/candle-examples/examples/whisper/multilingual.rs @@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) .iter() .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>"))) .collect::<Result<Vec<_>>>()?; - let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?; + let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?; let audio_features = model.encoder.forward(&mel, true)?; let tokens = Tensor::new(&[[sot_token]], device)?; let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; |