diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-10 09:40:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-10 09:40:27 +0100 |
commit | d3f05eae8c4f2df186b46e433be101ac39fceca5 (patch) | |
tree | 6ffc43595caec3007fe28efd3bafc7acbdde6e94 /candle-transformers | |
parent | 258ac32c3868d4103e90df19af99a3e13c805c4e (diff) | |
download | candle-d3f05eae8c4f2df186b46e433be101ac39fceca5.tar.gz candle-d3f05eae8c4f2df186b46e433be101ac39fceca5.tar.bz2 candle-d3f05eae8c4f2df186b46e433be101ac39fceca5.zip |
Move some models to candle-transformers so that it's easier to re-use. (#794)
* Move some models to candle-transformers so that they can be shared.
* Also move falcon.
* Move Llama.
* Move whisper (partial).
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/Cargo.toml | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/bert.rs | 568 | ||||
-rw-r--r-- | candle-transformers/src/models/bigcode.rs | 359 | ||||
-rw-r--r-- | candle-transformers/src/models/falcon.rs | 484 | ||||
-rw-r--r-- | candle-transformers/src/models/llama.rs | 446 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/models/whisper/audio.rs | 210 | ||||
-rw-r--r-- | candle-transformers/src/models/whisper/mod.rs | 26 | ||||
-rw-r--r-- | candle-transformers/src/models/whisper/model.rs | 416 |
9 files changed, 2518 insertions, 1 deletions
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index a05b9bb7..6b2087cb 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -14,7 +14,11 @@ accelerate-src = { workspace = true, optional = true } candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } candle-nn = { path = "../candle-nn", version = "0.2.1" } intel-mkl-src = { workspace = true, optional = true } +num-traits = { workspace = true } rand = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } wav = { workspace = true } [features] diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs new file mode 100644 index 00000000..3f164a3a --- /dev/null +++ b/candle-transformers/src/models/bert.rs @@ -0,0 +1,568 @@ +use candle::{DType, Device, Result, Tensor}; +use candle_nn::{Embedding, Module, VarBuilder}; +use serde::Deserialize; + +pub const DTYPE: DType = DType::F32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] +enum HiddenAct { + Gelu, + Relu, +} + +struct HiddenActLayer { + act: HiddenAct, + span: tracing::Span, +} + +impl HiddenActLayer { + fn new(act: HiddenAct) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "hidden-act"); + Self { act, span } + } + + fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { + let _enter = self.span.enter(); + match self.act { + // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some + // small numerical difference. + // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 + HiddenAct::Gelu => xs.gelu(), + HiddenAct::Relu => xs.relu(), + } + } +} + +#[derive(Debug)] +pub struct Linear { + weight: Tensor, + bias: Option<Tensor>, + span: tracing::Span, +} + +impl Linear { + pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Self { weight, bias, span } + } + + pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { + let _enter = self.span.enter(); + let w = match x.dims() { + &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, + _ => self.weight.t()?, + }; + let x = x.matmul(&w)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +#[derive(Debug)] +pub struct LayerNorm { + weight: Tensor, + bias: Tensor, + eps: f64, + span: tracing::Span, +} + +impl LayerNorm { + pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "layer-norm"); + Self { + weight, + bias, + eps, + span, + } + } + + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let (_bsize, _seq_len, hidden_size) = x.dims3()?; + let x = x.to_dtype(internal_dtype)?; + let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; + let x = x.broadcast_sub(&mean_x)?; + let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let x = x_normed + .to_dtype(x_dtype)? + .broadcast_mul(&self.weight)? + .broadcast_add(&self.bias)?; + Ok(x) + } +} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +enum PositionEmbeddingType { + #[default] + Absolute, +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + vocab_size: usize, + hidden_size: usize, + num_hidden_layers: usize, + num_attention_heads: usize, + intermediate_size: usize, + hidden_act: HiddenAct, + hidden_dropout_prob: f64, + max_position_embeddings: usize, + type_vocab_size: usize, + initializer_range: f64, + layer_norm_eps: f64, + pad_token_id: usize, + #[serde(default)] + position_embedding_type: PositionEmbeddingType, + #[serde(default)] + use_cache: bool, + classifier_dropout: Option<f64>, + model_type: Option<String>, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 30522, + hidden_size: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + intermediate_size: 3072, + hidden_act: HiddenAct::Gelu, + hidden_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + classifier_dropout: None, + model_type: Some("bert".to_string()), + } + } +} + +impl Config { + fn _all_mini_lm_l6_v2() -> Self { + // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json + Self { + vocab_size: 30522, + hidden_size: 384, + num_hidden_layers: 6, + num_attention_heads: 12, + intermediate_size: 1536, + hidden_act: HiddenAct::Gelu, + hidden_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + classifier_dropout: None, + model_type: Some("bert".to_string()), + } + } +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = vb.get(size2, "bias")?; + Ok(Linear::new(weight, Some(bias))) +} + +struct Dropout { + #[allow(dead_code)] + pr: f64, +} + +impl Dropout { + fn new(pr: f64) -> Self { + Self { pr } + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + // TODO + Ok(x.clone()) + } +} + +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { + let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { + (Ok(weight), Ok(bias)) => (weight, bias), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { + (weight, bias) + } else { + return Err(err); + } + } + }; + Ok(LayerNorm::new(weight, bias, eps)) +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180 +struct BertEmbeddings { + word_embeddings: Embedding, + position_embeddings: Option<Embedding>, + token_type_embeddings: Embedding, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertEmbeddings { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let word_embeddings = embedding( + config.vocab_size, + config.hidden_size, + vb.pp("word_embeddings"), + )?; + let position_embeddings = embedding( + config.max_position_embeddings, + config.hidden_size, + vb.pp("position_embeddings"), + )?; + let token_type_embeddings = embedding( + config.type_vocab_size, + config.hidden_size, + vb.pp("token_type_embeddings"), + )?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + Ok(Self { + word_embeddings, + position_embeddings: Some(position_embeddings), + token_type_embeddings, + layer_norm, + dropout: Dropout::new(config.hidden_dropout_prob), + span: tracing::span!(tracing::Level::TRACE, "embeddings"), + }) + } + + fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (_bsize, seq_len) = input_ids.dims2()?; + let input_embeddings = self.word_embeddings.forward(input_ids)?; + let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; + let mut embeddings = (&input_embeddings + token_type_embeddings)?; + if let Some(position_embeddings) = &self.position_embeddings { + // TODO: Proper absolute positions? + let position_ids = (0..seq_len as u32).collect::<Vec<_>>(); + let position_ids = Tensor::new(&position_ids[..], input_ids.device())?; + embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? + } + let embeddings = self.layer_norm.forward(&embeddings)?; + let embeddings = self.dropout.forward(&embeddings)?; + Ok(embeddings) + } +} + +struct BertSelfAttention { + query: Linear, + key: Linear, + value: Linear, + dropout: Dropout, + num_attention_heads: usize, + attention_head_size: usize, + span: tracing::Span, + span_softmax: tracing::Span, +} + +impl BertSelfAttention { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let attention_head_size = config.hidden_size / config.num_attention_heads; + let all_head_size = config.num_attention_heads * attention_head_size; + let dropout = Dropout::new(config.hidden_dropout_prob); + let hidden_size = config.hidden_size; + let query = linear(hidden_size, all_head_size, vb.pp("query"))?; + let value = linear(hidden_size, all_head_size, vb.pp("value"))?; + let key = linear(hidden_size, all_head_size, vb.pp("key"))?; + Ok(Self { + query, + key, + value, + dropout, + num_attention_heads: config.num_attention_heads, + attention_head_size, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"), + }) + } + + fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> { + let mut new_x_shape = xs.dims().to_vec(); + new_x_shape.pop(); + new_x_shape.push(self.num_attention_heads); + new_x_shape.push(self.attention_head_size); + let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; + xs.contiguous() + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let query_layer = self.query.forward(hidden_states)?; + let key_layer = self.key.forward(hidden_states)?; + let value_layer = self.value.forward(hidden_states)?; + + let query_layer = self.transpose_for_scores(&query_layer)?; + let key_layer = self.transpose_for_scores(&key_layer)?; + let value_layer = self.transpose_for_scores(&value_layer)?; + + let attention_scores = query_layer.matmul(&key_layer.t()?)?; + let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; + let attention_probs = { + let _enter_sm = self.span_softmax.enter(); + candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)? + }; + let attention_probs = self.dropout.forward(&attention_probs)?; + + let context_layer = attention_probs.matmul(&value_layer)?; + let context_layer = context_layer.transpose(1, 2)?.contiguous()?; + let context_layer = context_layer.flatten_from(candle::D::Minus2)?; + Ok(context_layer) + } +} + +struct BertSelfOutput { + dense: Linear, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertSelfOutput { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + let dropout = Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "self-out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392 +struct BertAttention { + self_attention: BertSelfAttention, + self_output: BertSelfOutput, + span: tracing::Span, +} + +impl BertAttention { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let self_attention = BertSelfAttention::load(vb.pp("self"), config)?; + let self_output = BertSelfOutput::load(vb.pp("output"), config)?; + Ok(Self { + self_attention, + self_output, + span: tracing::span!(tracing::Level::TRACE, "attn"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let self_outputs = self.self_attention.forward(hidden_states)?; + let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; + Ok(attention_output) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441 +struct BertIntermediate { + dense: Linear, + intermediate_act: HiddenActLayer, + span: tracing::Span, +} + +impl BertIntermediate { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; + Ok(Self { + dense, + intermediate_act: HiddenActLayer::new(config.hidden_act), + span: tracing::span!(tracing::Level::TRACE, "inter"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let ys = self.intermediate_act.forward(&hidden_states)?; + Ok(ys) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456 +struct BertOutput { + dense: Linear, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertOutput { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + let dropout = Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470 +struct BertLayer { + attention: BertAttention, + intermediate: BertIntermediate, + output: BertOutput, + span: tracing::Span, +} + +impl BertLayer { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let attention = BertAttention::load(vb.pp("attention"), config)?; + let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?; + let output = BertOutput::load(vb.pp("output"), config)?; + Ok(Self { + attention, + intermediate, + output, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let attention_output = self.attention.forward(hidden_states)?; + // TODO: Support cross-attention? + // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 + // TODO: Support something similar to `apply_chunking_to_forward`? + let intermediate_output = self.intermediate.forward(&attention_output)?; + let layer_output = self + .output + .forward(&intermediate_output, &attention_output)?; + Ok(layer_output) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556 +struct BertEncoder { + layers: Vec<BertLayer>, + span: tracing::Span, +} + +impl BertEncoder { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let layers = (0..config.num_hidden_layers) + .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config)) + .collect::<Result<Vec<_>>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + Ok(BertEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut hidden_states = hidden_states.clone(); + // Use a loop rather than a fold as it's easier to modify when adding debug/... + for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states)? + } + Ok(hidden_states) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874 +pub struct BertModel { + embeddings: BertEmbeddings, + encoder: BertEncoder, + pub device: Device, + span: tracing::Span, +} + +impl BertModel { + pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let (embeddings, encoder) = match ( + BertEmbeddings::load(vb.pp("embeddings"), config), + BertEncoder::load(vb.pp("encoder"), config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + if let Some(model_type) = &config.model_type { + if let (Ok(embeddings), Ok(encoder)) = ( + BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), + BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), + ) { + (embeddings, encoder) + } else { + return Err(err); + } + } else { + return Err(err); + } + } + }; + Ok(Self { + embeddings, + encoder, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?; + let sequence_output = self.encoder.forward(&embedding_output)?; + Ok(sequence_output) + } +} diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs new file mode 100644 index 00000000..1e63956b --- /dev/null +++ b/candle-transformers/src/models/bigcode.rs @@ -0,0 +1,359 @@ +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; + +fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = if bias { + Some(vb.get(size2, "bias")?) + } else { + None + }; + Ok(Linear::new(weight, bias)) +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; + Ok(LayerNorm::new(weight, bias, eps)) +} + +fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j <= i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), device)?; + Ok(mask) +} + +#[derive(Debug)] +pub struct Config { + pub vocab_size: usize, + // max_position_embeddings aka n_positions + pub max_position_embeddings: usize, + // num_hidden_layers aka n_layer + pub num_hidden_layers: usize, + // hidden_size aka n_embd + pub hidden_size: usize, + pub layer_norm_epsilon: f64, + pub n_inner: Option<usize>, + // num_attention_heads aka n_head + pub num_attention_heads: usize, + pub multi_query: bool, + pub use_cache: bool, +} + +impl Config { + #[allow(dead_code)] + pub fn starcoder_1b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 24, + hidden_size: 2048, + layer_norm_epsilon: 1e-5, + n_inner: Some(8192), + num_attention_heads: 16, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder_3b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 36, + hidden_size: 2816, + layer_norm_epsilon: 1e-5, + n_inner: Some(11264), + num_attention_heads: 22, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder_7b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 42, + hidden_size: 4096, + layer_norm_epsilon: 1e-5, + n_inner: Some(16384), + num_attention_heads: 32, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 40, + hidden_size: 6144, + layer_norm_epsilon: 1e-5, + n_inner: Some(24576), + num_attention_heads: 48, + multi_query: true, + use_cache: true, + } + } +} + +struct Attention { + c_attn: Linear, + c_proj: Linear, + kv_cache: Option<Tensor>, + use_cache: bool, + embed_dim: usize, + kv_dim: usize, + num_heads: usize, + head_dim: usize, + multi_query: bool, +} + +impl Attention { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let hidden_size = cfg.hidden_size; + let head_dim = hidden_size / cfg.num_attention_heads; + let kv_heads = if cfg.multi_query { + 1 + } else { + cfg.num_attention_heads + }; + let kv_dim = kv_heads * head_dim; + let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?; + let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?; + Ok(Self { + c_proj, + c_attn, + embed_dim: hidden_size, + kv_cache: None, + use_cache: cfg.use_cache, + kv_dim, + head_dim, + num_heads: cfg.num_attention_heads, + multi_query: cfg.multi_query, + }) + } + + fn attn( + &self, + query: &Tensor, + key: &Tensor, + value: &Tensor, + attention_mask: &Tensor, + ) -> Result<Tensor> { + if query.dtype() != DType::F32 { + // If we start supporting f16 models, we may need the upcasting scaling bits. + // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133 + candle::bail!("upcasting is not supported {:?}", query.dtype()) + } + let scale_factor = 1f64 / (self.head_dim as f64).sqrt(); + let initial_query_shape = query.shape(); + let key_len = key.dim(D::Minus1)?; + let (query, key, attn_shape, attn_view) = if self.multi_query { + let (b_sz, query_len, _) = query.dims3()?; + let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; + let attn_shape = (b_sz, query_len, self.num_heads, key_len); + let attn_view = (b_sz, query_len * self.num_heads, key_len); + (query, key.clone(), attn_shape, attn_view) + } else { + let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?; + let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; + let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?; + let attn_shape = (b_sz, self.num_heads, query_len, key_len); + let attn_view = (b_sz * self.num_heads, query_len, key_len); + (query, key, attn_shape, attn_view) + }; + + let attn_weights = + (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?; + let attention_mask = attention_mask.broadcast_as(attn_shape)?; + let mask_value = + Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; + let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; + let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + let value = value.contiguous()?; + let attn_output = if self.multi_query { + attn_weights + .reshape(attn_view)? + .matmul(&value)? + .reshape(initial_query_shape)? + } else { + attn_weights.matmul(&value)? + }; + Ok(attn_output) + } + + fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { + let qkv = self.c_attn.forward(hidden_states)?; + let (query, key_value) = if self.multi_query { + let query = qkv.i((.., .., ..self.embed_dim))?; + let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?; + (query, key_value) + } else { + let mut dims = qkv.dims().to_vec(); + dims.pop(); + dims.push(self.embed_dim); + dims.push(self.head_dim * 3); + let qkv = qkv.reshape(dims)?.transpose(1, 2)?; + let query = qkv.i((.., .., .., ..self.head_dim))?; + let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?; + (query, key_value) + }; + let mut key_value = key_value; + if self.use_cache { + if let Some(kv_cache) = &self.kv_cache { + // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for + // arbitrarily large sizes. + key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?; + } + self.kv_cache = Some(key_value.clone()) + } + + let key = key_value.narrow(D::Minus1, 0, self.head_dim)?; + let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?; + let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?; + let attn_output = if self.multi_query { + attn_output + } else { + attn_output + .transpose(1, 2)? + .reshape(hidden_states.shape())? + }; + let attn_output = self.c_proj.forward(&attn_output)?; + Ok(attn_output) + } +} + +struct Mlp { + c_fc: Linear, + c_proj: Linear, +} + +impl Mlp { + fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?; + let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?; + Ok(Self { c_fc, c_proj }) + } + + fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> { + let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?; + let hidden_states = self.c_proj.forward(&hidden_states)?; + Ok(hidden_states) + } +} + +// TODO: Add cross-attention? +struct Block { + ln_1: LayerNorm, + attn: Attention, + ln_2: LayerNorm, + mlp: Mlp, +} + +impl Block { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let hidden_size = cfg.hidden_size; + let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size); + let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?; + let attn = Attention::load(vb.pp("attn"), cfg)?; + let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?; + let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?; + Ok(Self { + ln_1, + attn, + ln_2, + mlp, + }) + } + + fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { + let residual = hidden_states; + let hidden_states = self.ln_1.forward(hidden_states)?; + let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?; + let hidden_states = (&attn_outputs + residual)?; + let residual = &hidden_states; + let hidden_states = self.ln_2.forward(&hidden_states)?; + let hidden_states = self.mlp.forward(&hidden_states)?; + let hidden_states = (&hidden_states + residual)?; + Ok(hidden_states) + } +} + +pub struct GPTBigCode { + wte: Embedding, + wpe: Embedding, + blocks: Vec<Block>, + ln_f: LayerNorm, + lm_head: Linear, + bias: Tensor, + config: Config, +} + +impl GPTBigCode { + pub fn config(&self) -> &Config { + &self.config + } + + pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { + let hidden_size = cfg.hidden_size; + let vb_t = vb.pp("transformer"); + let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?; + let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?; + let blocks = (0..cfg.num_hidden_layers) + .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg)) + .collect::<Result<Vec<_>>>()?; + let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?; + let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?; + let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?; + Ok(Self { + wte, + wpe, + blocks, + lm_head, + ln_f, + bias, + config: cfg, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> { + let dev = input_ids.device(); + let (b_sz, seq_len) = input_ids.dims2()?; + + let key_len = past_len + seq_len; + let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?; + // MQA models: (batch_size, query_length, n_heads, key_length) + // MHA models: (batch_size, n_heads, query_length, key_length) + let seq_len_dim = if self.config.multi_query { 2 } else { 1 }; + let attention_mask = attention_mask.unsqueeze(seq_len_dim)?; + + let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?; + let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?; + let input_embeds = self.wte.forward(input_ids)?; + let position_embeds = self.wpe.forward(&position_ids)?; + + let mut hidden_states = (&input_embeds + &position_embeds)?; + for block in self.blocks.iter_mut() { + hidden_states = block.forward(&hidden_states, &attention_mask)?; + } + let hidden_states = self.ln_f.forward(&hidden_states)?; + let hidden_states = hidden_states + .reshape((b_sz, seq_len, self.config.hidden_size))? + .narrow(1, seq_len - 1, 1)?; + let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?; + Ok(logits) + } +} diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs new file mode 100644 index 00000000..6ede136a --- /dev/null +++ b/candle-transformers/src/models/falcon.rs @@ -0,0 +1,484 @@ +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; + +const MAX_SEQ_LEN: usize = 5000; + +fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = if bias { + Some(vb.get(size2, "bias")?) + } else { + None + }; + Ok(Linear::new(weight, bias)) +} + +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { + let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { + (Ok(weight), Ok(bias)) => (weight, bias), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { + (weight, bias) + } else { + return Err(err); + } + } + }; + Ok(LayerNorm::new(weight, bias, eps)) +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py +#[derive(Debug)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub layer_norm_epsilon: f64, + pub initializer_range: f64, + pub use_cache: bool, + pub bos_token_id: u32, + pub eos_token_id: u32, + pub hidden_dropout: f64, + pub attention_dropout: f64, + pub n_head_kv: Option<usize>, + pub alibi: bool, + pub new_decoder_architecture: bool, + pub multi_query: bool, + pub parallel_attn: bool, + pub bias: bool, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 65024, + hidden_size: 4544, + num_hidden_layers: 32, + num_attention_heads: 71, + layer_norm_epsilon: 1e-5, + initializer_range: 0.02, + use_cache: true, + bos_token_id: 11, + eos_token_id: 11, + hidden_dropout: 0.0, + attention_dropout: 0.0, + n_head_kv: None, + alibi: false, + new_decoder_architecture: false, + multi_query: true, + parallel_attn: true, + bias: false, + } + } +} + +impl Config { + pub fn validate(&self) -> Result<()> { + if self.alibi { + candle::bail!("alibi is not supported"); + } + if self.new_decoder_architecture { + candle::bail!("new_decoder_architecture is not supported"); + } + if self.n_head_kv.is_some() { + candle::bail!("n_head_kv is not supported"); + } + Ok(()) + } + + // https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json + pub fn falcon7b() -> Self { + // This is currently on par with the defaults, the defaults come from the Python default + // arguments for the config initialization whereas the following come from the json config. + Self { + vocab_size: 65024, + hidden_size: 4544, + num_hidden_layers: 32, + num_attention_heads: 71, + layer_norm_epsilon: 1e-5, + initializer_range: 0.02, + use_cache: true, + bos_token_id: 11, + eos_token_id: 11, + hidden_dropout: 0., + attention_dropout: 0., + n_head_kv: None, + alibi: false, + new_decoder_architecture: false, + multi_query: true, + parallel_attn: true, + bias: false, + } + } + + fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } + + fn rotary(&self) -> bool { + !self.alibi + } +} + +fn rotate_half(x: &Tensor) -> Result<Tensor> { + let l = x.dim(D::Minus1)?; + let x1 = x.narrow(D::Minus1, 0, l / 2)?; + let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?; + let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; + Ok(x21) +} + +#[derive(Debug)] +struct FalconRotaryEmbedding { + inv_freq: Tensor, + cache: Option<(usize, Tensor, Tensor)>, +} + +impl FalconRotaryEmbedding { + fn load(device: &Device, cfg: &Config) -> Result<Self> { + let head_dim = cfg.head_dim(); + let inv_freq: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32)) + .collect(); + Ok(Self { + inv_freq: Tensor::new(inv_freq.as_slice(), device)?, + cache: None, + }) + } + + fn cos_sin( + &mut self, + seq_len: usize, + device: &Device, + dtype: DType, + ) -> Result<(Tensor, Tensor)> { + match &self.cache { + Some((s, cos, sin)) if *s == seq_len => { + return Ok((cos.clone(), sin.clone())); + } + _ => {} + } + let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?; + let inv_freq = self.inv_freq.to_dtype(dtype)?; + let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?; + let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + let cos = emb.cos()?; + let sin = emb.sin()?; + self.cache = Some((seq_len, cos.clone(), sin.clone())); + Ok((cos, sin)) + } + + fn forward( + &mut self, + query: &Tensor, + key: &Tensor, + past_kv_len: usize, + ) -> Result<(Tensor, Tensor)> { + let (_batch, seq_len, _head_dim) = query.dims3()?; + let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?; + let cos = cos.narrow(0, past_kv_len, seq_len)?; + let sin = sin.narrow(0, past_kv_len, seq_len)?; + let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?; + let ks = (key.broadcast_mul(&cos)? + &rotate_half(key)?.broadcast_mul(&sin)?)?; + Ok((qs, ks)) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[derive(Debug)] +struct FalconAttention { + query_key_value: Linear, + dense: Linear, + maybe_rotary: Option<FalconRotaryEmbedding>, + kv_cache: Option<(Tensor, Tensor)>, + inv_norm_factor: f64, + multi_query: bool, + use_cache: bool, + num_heads: usize, + head_dim: usize, + n_head_kv: usize, +} + +impl FalconAttention { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let maybe_rotary = if cfg.rotary() { + let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?; + Some(rotary) + } else { + None + }; + let head_dim = cfg.head_dim(); + let hidden_size = cfg.hidden_size; + let qkv_out_dim = if cfg.multi_query { + hidden_size + 2 * head_dim + } else { + 3 * hidden_size + }; + let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?; + let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?; + Ok(Self { + query_key_value, + dense, + maybe_rotary, + kv_cache: None, + inv_norm_factor: 1. / (head_dim as f64).sqrt(), + multi_query: cfg.multi_query, + use_cache: cfg.use_cache, + num_heads: cfg.num_attention_heads, + n_head_kv: cfg.n_head_kv.unwrap_or(1), + head_dim, + }) + } + + fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let (b_sz, seq_len, _) = fused_qkv.dims3()?; + if !self.multi_query { + let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?; + let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?; + let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?; + let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?; + Ok((q, k, v)) + } else { + let fused_qkv = + fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?; + let d = fused_qkv.dim(D::Minus2)?; + let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?; + let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?; + let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?; + Ok((q, k, v)) + } + } + + fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> { + let fused_qkv = self.query_key_value.forward(x)?; + let head_dim = self.head_dim; + let (query, key, value) = self.split_heads(&fused_qkv)?; + let (b_sz, seq_len, _, _) = query.dims4()?; + let query = query + .transpose(1, 2)? + .reshape((b_sz * self.num_heads, seq_len, head_dim))?; + let key = key + .transpose(1, 2)? + .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?; + let value = value + .transpose(1, 2)? + .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?; + let (query, key) = if let Some(r) = &mut self.maybe_rotary { + r.forward(&query, &key, past_kv_len)? + } else { + (query, key) + }; + let (mut key, mut value) = (key, value); + let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?; + if self.use_cache { + if let Some((cache_k, cache_v)) = &self.kv_cache { + // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for + // arbitrarily large sizes. + key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?; + value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?; + } + self.kv_cache = Some((key.clone(), value.clone())) + } + let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?; + let all_len = past_kv_len + seq_len; + let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?; + let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?; + + let (key, value) = if self.n_head_kv == 1 { + ( + key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?, + value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?, + ) + } else { + (key, value) + }; + + // Only handle the case where alibi is None here, and non-flash attention. + let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?; + let attention_scores = candle_nn::ops::softmax( + &attention_scores + .broadcast_add(&mask.squeeze(1)?)? + .to_dtype(DType::F32)?, + D::Minus1, + )? + .to_dtype(x.dtype())?; + let attn_output = attention_scores + .matmul(&value)? + .reshape((b_sz, self.num_heads, seq_len, head_dim))? + .transpose(1, 2)? + .reshape((b_sz, seq_len, self.num_heads * head_dim))?; + let attn_output = self.dense.forward(&attn_output)?; + Ok(attn_output) + } +} + +#[derive(Debug)] +struct FalconMlp { + dense_h_to_4h: Linear, + dense_4h_to_h: Linear, +} + +impl FalconMlp { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let h = cfg.hidden_size; + let b = cfg.bias; + let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?; + let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?; + Ok(Self { + dense_h_to_4h, + dense_4h_to_h, + }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x = self.dense_h_to_4h.forward(x)?.gelu()?; + let x = self.dense_4h_to_h.forward(&x)?; + Ok(x) + } +} + +#[derive(Debug)] +struct FalconDecoderLayer { + inp_layernorm: LayerNorm, + self_attention: FalconAttention, + post_attention_layernorm: Option<LayerNorm>, + mlp: FalconMlp, + parallel_attn: bool, +} + +impl FalconDecoderLayer { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?; + let inp_layernorm = layer_norm( + cfg.hidden_size, + cfg.layer_norm_epsilon, + vb.pp("input_layernorm"), + )?; + let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?; + let post_attention_layernorm = if cfg.parallel_attn { + None + } else { + let ln = layer_norm( + cfg.hidden_size, + cfg.layer_norm_epsilon, + vb.pp("post_attention_layernorm"), + )?; + Some(ln) + }; + Ok(Self { + inp_layernorm, + self_attention, + post_attention_layernorm, + mlp, + parallel_attn: cfg.parallel_attn, + }) + } + + fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> { + let residual = x.clone(); + let ln_attn = self.inp_layernorm.forward(x)?; + let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?; + let (residual, ln_mlp) = match &self.post_attention_layernorm { + None => (residual, ln_attn), + Some(pal) => { + // This should include some dropout. + let residual = (&attn_output + &residual)?; + let ln_mlp = pal.forward(&residual)?; + (residual, ln_mlp) + } + }; + let mlp_output = self.mlp.forward(&ln_mlp)?; + + let mlp_output = if self.parallel_attn { + (mlp_output + attn_output)? + } else { + mlp_output + }; + let output = (mlp_output + residual)?; + Ok(output) + } +} + +#[derive(Debug)] +pub struct Falcon { + word_embeddings: Embedding, + blocks: Vec<FalconDecoderLayer>, + ln_f: LayerNorm, + lm_head: Linear, + config: Config, +} + +fn make_causal_mask(t: usize) -> Result<Tensor> { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; + Ok(mask) +} + +fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> { + // let mask = Tensor::ones((b_sz, seq_len), DType::U32, &Device::Cpu)?; + let mask = make_causal_mask(seq_len)?; + let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?; + Ok(mask) +} + +impl Falcon { + pub fn config(&self) -> &Config { + &self.config + } + + pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { + let word_embeddings = embedding( + cfg.vocab_size, + cfg.hidden_size, + vb.pp("transformer.word_embeddings"), + )?; + let blocks = (0..cfg.num_hidden_layers) + .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg)) + .collect::<Result<Vec<_>>>()?; + let ln_f = layer_norm( + cfg.hidden_size, + cfg.layer_norm_epsilon, + vb.pp("transformer.ln_f"), + )?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?; + Ok(Self { + word_embeddings, + blocks, + ln_f, + lm_head, + config: cfg, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { + let (b_sz, seq_len) = input_ids.dims2()?; + let mut hidden_state = self.word_embeddings.forward(input_ids)?; + let past_kv_len = match &self.blocks[0].self_attention.kv_cache { + Some((k, _)) => k.dim(1)?, + None => 0, + }; + let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?; + for block in self.blocks.iter_mut() { + hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?; + } + let hidden_state = self.ln_f.forward(&hidden_state)?; + let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?; + let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?; + Ok(logits) + } +} diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs new file mode 100644 index 00000000..eed4df5e --- /dev/null +++ b/candle-transformers/src/models/llama.rs @@ -0,0 +1,446 @@ +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module, VarBuilder}; +use serde::Deserialize; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +pub const MAX_SEQ_LEN: usize = 4096; + +#[derive(Deserialize)] +pub struct LlamaConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: Option<usize>, + pub rms_norm_eps: f64, + #[serde(default = "default_rope")] + pub rope_theta: f32, +} + +fn default_rope() -> f32 { + 10_000.0 +} + +impl LlamaConfig { + pub fn into_config(self, use_flash_attn: bool) -> Config { + Config { + hidden_size: self.hidden_size, + intermediate_size: self.intermediate_size, + vocab_size: self.vocab_size, + num_hidden_layers: self.num_hidden_layers, + num_attention_heads: self.num_attention_heads, + num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads), + rms_norm_eps: self.rms_norm_eps, + rope_theta: self.rope_theta, + use_flash_attn, + } + } +} + +pub struct Config { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub use_flash_attn: bool, + pub rms_norm_eps: f64, + pub rope_theta: f32, +} + +impl Config { + pub fn config_7b_v1(use_flash_attn: bool) -> Self { + Self { + hidden_size: 4096, + intermediate_size: 11008, + vocab_size: 32000, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, + use_flash_attn, + rms_norm_eps: 1e-6, + rope_theta: 10_000.0, + } + } + + pub fn config_7b_v2(use_flash_attn: bool) -> Self { + Self { + hidden_size: 4096, + intermediate_size: 11008, + vocab_size: 32000, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, + use_flash_attn, + rms_norm_eps: 1e-5, + rope_theta: 10_000.0, + } + } +} + +// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting +// model. +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +#[derive(Clone)] +pub struct Cache { + masks: Arc<Mutex<HashMap<usize, Tensor>>>, + pub use_kv_cache: bool, + #[allow(clippy::type_complexity)] + kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, + cos: Tensor, + sin: Tensor, + device: Device, +} + +impl Cache { + pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> { + // precompute freqs_cis + let n_elem = config.hidden_size / config.num_attention_heads; + let theta: Vec<_> = (0..n_elem) + .step_by(2) + .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + // This is different from the paper, see: + // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 + let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?; + let cos = idx_theta.cos()?.to_dtype(dtype)?; + let sin = idx_theta.sin()?.to_dtype(dtype)?; + Ok(Self { + masks: Arc::new(Mutex::new(HashMap::new())), + use_kv_cache, + kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])), + device: device.clone(), + cos, + sin, + }) + } + + fn mask(&self, t: usize) -> Result<Tensor> { + let mut masks = self.masks.lock().unwrap(); + if let Some(mask) = masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; + masks.insert(t, mask.clone()); + Ok(mask) + } + } +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear_no_bias(size1, size2, vb)?; + Ok(Linear { inner, span }) +} + +fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; + Ok(Embedding::new(embeddings, cfg.hidden_size)) +} + +struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let inner = candle_nn::rms_norm(size, eps, vb)?; + Ok(Self { inner, span }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +struct CausalSelfAttention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_attention_heads: usize, + num_key_value_heads: usize, + head_dim: usize, + cache: Cache, + use_flash_attn: bool, + span: tracing::Span, + span_rot: tracing::Span, +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); + let (b_sz, _, seq_len, hidden_size) = x.dims4()?; + let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; + let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?; + let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?; + let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?; + let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; + let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?; + Ok(rope) + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b_sz, seq_len, hidden_size) = x.dims3()?; + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)?; + let mut v = v + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let mut k = self.apply_rotary_emb(&k, index_pos)?; + + if self.cache.use_kv_cache { + let mut cache = self.cache.kvs.lock().unwrap(); + if let Some((cache_k, cache_v)) = &cache[block_idx] { + k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; + let k_seq_len = k.dims()[1]; + if k_seq_len > MAX_SEQ_LEN { + k = k + .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + let v_seq_len = v.dims()[1]; + if v_seq_len > 2 * MAX_SEQ_LEN { + v = v + .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + } + cache[block_idx] = Some((k.clone(), v.clone())) + } + + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let y = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)? + } else { + let in_dtype = q.dtype(); + let q = q.to_dtype(DType::F32)?; + let k = k.to_dtype(DType::F32)?; + let v = v.to_dtype(DType::F32)?; + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)? + }; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?; + let y = self.o_proj.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { + let n_rep = self.num_attention_heads / self.num_key_value_heads; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; + let x = x + .unsqueeze(2)? + .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? + .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; + Ok(x) + } + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let size_in = cfg.hidden_size; + let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; + let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; + let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; + let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; + let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; + let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + head_dim: cfg.hidden_size / cfg.num_attention_heads, + cache: cache.clone(), + use_flash_attn: cfg.use_flash_attn, + span, + span_rot, + }) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +struct Mlp { + c_fc1: Linear, + c_fc2: Linear, + c_proj: Linear, + span: tracing::Span, +} + +impl Mlp { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "mlp"); + let h_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; + let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; + let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; + Ok(Self { + c_fc1, + c_fc2, + c_proj, + span, + }) + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, + span: tracing::Span, +} + +impl Block { + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { + let _enter = self.span.enter(); + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; + Ok(x) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "block"); + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; + let mlp = Mlp::load(vb.pp("mlp"), cfg)?; + let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let rms_2 = RmsNorm::load( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + rms_1, + attn, + rms_2, + mlp, + span, + }) + } +} + +pub struct Llama { + wte: Embedding, + blocks: Vec<Block>, + ln_f: RmsNorm, + lm_head: Linear, +} + +impl Llama { + pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = x.dims2()?; + let mut x = self.wte.forward(x)?; + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.num_hidden_layers) + .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap()) + .collect(); + + Ok(Self { + wte, + blocks, + ln_f, + lm_head, + }) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 8b137891..1b3dcf25 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -1 +1,5 @@ - +pub mod bert; +pub mod bigcode; +pub mod falcon; +pub mod llama; +pub mod whisper; diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs new file mode 100644 index 00000000..4e01de32 --- /dev/null +++ b/candle-transformers/src/models/whisper/audio.rs @@ -0,0 +1,210 @@ +// Audio processing code, adapted from whisper.cpp +// https://github.com/ggerganov/whisper.cpp + +pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} + +impl Float for f32 {} +impl Float for f64 {} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 +fn fft<T: Float>(inp: &[T]) -> Vec<T> { + let n = inp.len(); + let zero = T::zero(); + if n == 1 { + return vec![inp[0], zero]; + } + if n % 2 == 1 { + return dft(inp); + } + let mut out = vec![zero; n * 2]; + + let mut even = Vec::with_capacity(n / 2); + let mut odd = Vec::with_capacity(n / 2); + + for (i, &inp) in inp.iter().enumerate() { + if i % 2 == 0 { + even.push(inp) + } else { + odd.push(inp); + } + } + + let even_fft = fft(&even); + let odd_fft = fft(&odd); + + let two_pi = T::PI() + T::PI(); + let n_t = T::from(n).unwrap(); + for k in 0..n / 2 { + let k_t = T::from(k).unwrap(); + let theta = two_pi * k_t / n_t; + let re = theta.cos(); + let im = -theta.sin(); + + let re_odd = odd_fft[2 * k]; + let im_odd = odd_fft[2 * k + 1]; + + out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; + + out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd; + out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; + } + out +} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337 +fn dft<T: Float>(inp: &[T]) -> Vec<T> { + let zero = T::zero(); + let n = inp.len(); + let two_pi = T::PI() + T::PI(); + + let mut out = Vec::new(); + out.reserve(2 * n); + let n_t = T::from(n).unwrap(); + for k in 0..n { + let k_t = T::from(k).unwrap(); + let mut re = zero; + let mut im = zero; + + for (j, &inp) in inp.iter().enumerate() { + let j_t = T::from(j).unwrap(); + let angle = two_pi * k_t * j_t / n_t; + re += inp * angle.cos(); + im -= inp * angle.sin(); + } + + out.push(re); + out.push(im); + } + out +} + +#[allow(clippy::too_many_arguments)] +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414 +fn log_mel_spectrogram_w<T: Float>( + ith: usize, + hann: &[T], + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + speed_up: bool, + n_len: usize, + n_mel: usize, + n_threads: usize, +) -> Vec<T> { + let n_fft = if speed_up { + 1 + fft_size / 4 + } else { + 1 + fft_size / 2 + }; + + let zero = T::zero(); + let half = T::from(0.5).unwrap(); + let mut fft_in = vec![zero; fft_size]; + let mut mel = vec![zero; n_len * n_mel]; + + for i in (ith..n_len).step_by(n_threads) { + let offset = i * fft_step; + + // apply Hanning window + for j in 0..fft_size { + fft_in[j] = if offset + j < samples.len() { + hann[j] * samples[offset + j] + } else { + zero + } + } + + // FFT -> mag^2 + let mut fft_out: Vec<T> = fft(&fft_in); + + for j in 0..fft_size { + fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1]; + } + for j in 1..fft_size / 2 { + let v = fft_out[fft_size - j]; + fft_out[j] += v; + } + + if speed_up { + // scale down in the frequency domain results in a speed up in the time domain + for j in 0..n_fft { + fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]); + } + } + + // mel spectrogram + for j in 0..n_mel { + let mut sum = zero; + for k in 0..n_fft { + sum += fft_out[k] * filters[j * n_fft + k]; + } + mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); + } + } + mel +} + +fn log_mel_spectrogram_<T: Float + std::fmt::Display>( + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + n_mel: usize, + speed_up: bool, +) -> Vec<T> { + let zero = T::zero(); + let two_pi = T::PI() + T::PI(); + let half = T::from(0.5).unwrap(); + let one = T::from(1.0).unwrap(); + let four = T::from(4.0).unwrap(); + let fft_size_t = T::from(fft_size).unwrap(); + + let hann: Vec<T> = (0..fft_size) + .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) + .collect(); + let n_len = samples.len() / fft_step; + + // pad audio with at least one extra chunk of zeros + let pad = 100 * super::CHUNK_LENGTH / 2; + let n_len = if n_len % pad != 0 { + (n_len / pad + 1) * pad + } else { + n_len + }; + let n_len = n_len + pad; + let samples = { + let mut samples_padded = samples.to_vec(); + let to_add = n_len * fft_step - samples.len(); + samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded + }; + + // Use a single thread for now. + let mut mel = log_mel_spectrogram_w( + 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1, + ); + let mmax = mel + .iter() + .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) + .copied() + .unwrap_or(zero) + - T::from(8).unwrap(); + for m in mel.iter_mut() { + let v = T::max(*m, mmax); + *m = v / four + one + } + mel +} + +pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> { + log_mel_spectrogram_( + samples, + filters, + super::N_FFT, + super::HOP_LENGTH, + super::N_MELS, + false, + ) +} diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs new file mode 100644 index 00000000..7dc8107b --- /dev/null +++ b/candle-transformers/src/models/whisper/mod.rs @@ -0,0 +1,26 @@ +pub mod audio; +pub mod model; + +pub const DTYPE: candle::DType = candle::DType::F32; + +// Audio parameters. +pub const SAMPLE_RATE: usize = 16000; +pub const N_FFT: usize = 400; +pub const N_MELS: usize = 80; +pub const HOP_LENGTH: usize = 160; +pub const CHUNK_LENGTH: usize = 30; +pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk +pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input + +pub const NO_SPEECH_THRESHOLD: f64 = 0.6; +pub const LOGPROB_THRESHOLD: f64 = -1.0; +pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; +pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; + +// Tokenizer dependent bits. +pub const SOT_TOKEN: &str = "<|startoftranscript|>"; +pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>"; +pub const TRANSLATE_TOKEN: &str = "<|translate|>"; +pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>"; +pub const EOT_TOKEN: &str = "<|endoftext|>"; +pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs new file mode 100644 index 00000000..e58ab2ca --- /dev/null +++ b/candle-transformers/src/models/whisper/model.rs @@ -0,0 +1,416 @@ +use candle::{Device, IndexOp, Result, Tensor, D}; +use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; +use serde::Deserialize; + +// The names in comments correspond to the original implementation: +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub num_mel_bins: usize, // n_mels + pub max_source_positions: usize, // n_audio_ctx + pub d_model: usize, // n_audio_state + pub encoder_attention_heads: usize, // n_audio_head + pub encoder_layers: usize, // n_audio_layer + pub vocab_size: usize, // n_vocab + pub max_target_positions: usize, // n_text_ctx + // pub n_text_state: usize, + pub decoder_attention_heads: usize, // n_text_head + pub decoder_layers: usize, // n_text_layer + pub suppress_tokens: Vec<u32>, +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} +// +// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting +// model. +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear(size1, size2, vb)?; + Ok(Linear { inner, span }) +} + +fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear_no_bias(size1, size2, vb)?; + Ok(Linear { inner, span }) +} + +fn conv1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: Conv1dConfig, + vb: VarBuilder, +) -> Result<Conv1d> { + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; + let bias = vb.get(out_channels, "bias")?; + Ok(Conv1d::new(weight, Some(bias), config)) +} + +fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; + Ok(LayerNorm::new(weight, bias, 1e-5)) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +struct MultiHeadAttention { + query: Linear, + key: Linear, + value: Linear, + out: Linear, + n_head: usize, + span: tracing::Span, + softmax_span: tracing::Span, + matmul_span: tracing::Span, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl MultiHeadAttention { + fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); + let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax"); + let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul"); + let query = linear(n_state, n_state, vb.pp("q_proj"))?; + let value = linear(n_state, n_state, vb.pp("v_proj"))?; + let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; + let out = linear(n_state, n_state, vb.pp("out_proj"))?; + Ok(Self { + query, + key, + value, + out, + n_head, + span, + softmax_span, + matmul_span, + kv_cache: None, + }) + } + + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_cache: bool, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let q = self.query.forward(x)?; + let (k, v) = match xa { + None => { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + (k, v) + } + Some(x) => { + if flush_cache { + self.kv_cache = None; + } + if let Some((k, v)) = &self.kv_cache { + (k.clone(), v.clone()) + } else { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + self.kv_cache = Some((k.clone(), v.clone())); + (k, v) + } + } + }; + let wv = self.qkv_attention(&q, &k, &v, mask)?; + let out = self.out.forward(&wv)?; + Ok(out) + } + + fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { + let (n_batch, n_ctx, n_state) = x.dims3()?; + let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; + x.reshape(target_dims)?.transpose(1, 2) + } + + fn qkv_attention( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + ) -> Result<Tensor> { + let (_, n_ctx, n_state) = q.dims3()?; + let scale = ((n_state / self.n_head) as f64).powf(-0.25); + let q = (self.reshape_head(q)? * scale)?; + let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; + let v = self.reshape_head(v)?.contiguous()?; + let mut qk = { + let _enter = self.matmul_span.enter(); + q.matmul(&k)? + }; + if let Some(mask) = mask { + let mask = mask.i((0..n_ctx, 0..n_ctx))?; + qk = qk.broadcast_add(&mask)? + } + let w = { + let _enter = self.softmax_span.enter(); + softmax(&qk, D::Minus1)? + }; + let wv = { + let _enter = self.matmul_span.enter(); + w.matmul(&v)? + } + .transpose(1, 2)? + .flatten_from(2)?; + Ok(wv) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +struct ResidualAttentionBlock { + attn: MultiHeadAttention, + attn_ln: LayerNorm, + cross_attn: Option<(MultiHeadAttention, LayerNorm)>, + mlp_linear1: Linear, + mlp_linear2: Linear, + mlp_ln: LayerNorm, + span: tracing::Span, +} + +impl ResidualAttentionBlock { + fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); + let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; + let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; + let cross_attn = if ca { + let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; + let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; + Some((cross_attn, cross_attn_ln)) + } else { + None + }; + let n_mlp = n_state * 4; + let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; + let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; + let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; + Ok(Self { + attn, + attn_ln, + cross_attn, + mlp_linear1, + mlp_linear2, + mlp_ln, + span, + }) + } + + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_kv_cache: bool, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let attn = self + .attn + .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; + let mut x = (x + attn)?; + if let Some((attn, ln)) = &mut self.cross_attn { + x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; + } + let mlp = self.mlp_linear2.forward( + &self + .mlp_linear1 + .forward(&self.mlp_ln.forward(&x)?)? + .gelu()?, + )?; + x + mlp + } +} + +fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { + let max_timescale = 10000f32; + let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; + let inv_timescales: Vec<_> = (0..channels / 2) + .map(|i| (i as f32 * (-log_timescale_increment)).exp()) + .collect(); + let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let arange = Tensor::arange(0, length as u32, &Device::Cpu)? + .to_dtype(candle::DType::F32)? + .unsqueeze(1)?; + let sh = (length, channels / 2); + let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; + let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; + Ok(sincos) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +pub struct AudioEncoder { + conv1: Conv1d, + conv2: Conv1d, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln_post: LayerNorm, + span: tracing::Span, + conv1_span: tracing::Span, + conv2_span: tracing::Span, +} + +impl AudioEncoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); + let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); + let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); + let n_state = cfg.d_model; + let n_head = cfg.encoder_attention_heads; + let n_ctx = cfg.max_source_positions; + let cfg1 = Conv1dConfig { + padding: 1, + stride: 1, + groups: 1, + dilation: 1, + }; + let cfg2 = Conv1dConfig { + padding: 1, + stride: 2, + groups: 1, + dilation: 1, + }; + let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; + let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; + let blocks = (0..cfg.encoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; + Ok(Self { + conv1, + conv2, + positional_embedding, + blocks, + ln_post, + conv1_span, + conv2_span, + span, + }) + } + + pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { + let _enter = self.span.enter(); + let x = { + let _enter = self.conv1_span.enter(); + self.conv1.forward(x)?.gelu()? + }; + let x = { + let _enter = self.conv2_span.enter(); + self.conv2.forward(&x)?.gelu()? + }; + let x = x.transpose(1, 2)?; + let (_bsize, seq_len, _hidden) = x.dims3()?; + let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; + let mut x = x.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, None, None, flush_kv_cache)? + } + let x = self.ln_post.forward(&x)?; + Ok(x) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +pub struct TextDecoder { + token_embedding: Embedding, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln: LayerNorm, + mask: Tensor, + span: tracing::Span, + span_final: tracing::Span, +} + +impl TextDecoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); + let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final"); + let n_state = cfg.d_model; + let n_head = cfg.decoder_attention_heads; + let n_ctx = cfg.max_target_positions; + let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; + let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; + let blocks = (0..cfg.decoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln = layer_norm(n_state, vb.pp("layer_norm"))?; + let mask: Vec<_> = (0..n_ctx) + .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; + Ok(Self { + token_embedding, + positional_embedding, + blocks, + ln, + mask, + span, + span_final, + }) + } + + pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { + let _enter = self.span.enter(); + let last = x.dim(D::Minus1)?; + let token_embedding = self.token_embedding.forward(x)?; + let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; + let mut x = token_embedding.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; + } + self.ln.forward(&x) + } + + pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> { + let b_size = x.dim(0)?; + let w = self.token_embedding.embeddings().broadcast_left(b_size)?; + let logits = { + let _enter = self.span_final.enter(); + x.matmul(&w.t()?)? + }; + Ok(logits) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +pub struct Whisper { + pub encoder: AudioEncoder, + pub decoder: TextDecoder, + pub config: Config, +} + +impl Whisper { + pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { + let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; + let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; + Ok(Self { + encoder, + decoder, + config, + }) + } +} |