diff options
Diffstat (limited to 'candle-examples/examples/musicgen/t5_model.rs')
-rw-r--r-- | candle-examples/examples/musicgen/t5_model.rs | 46 |
1 files changed, 15 insertions, 31 deletions
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs index 613b4112..33b11b95 100644 --- a/candle-examples/examples/musicgen/t5_model.rs +++ b/candle-examples/examples/musicgen/t5_model.rs @@ -1,10 +1,8 @@ // T5 Text Encoder // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py -use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder}; -use anyhow::Result; -use candle::{DType, Tensor, D}; -use candle_nn::Module; +use candle::{DType, Result, Tensor, D}; +use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder}; use std::sync::Arc; #[derive(Debug, Clone, PartialEq)] @@ -21,7 +19,7 @@ pub struct Config { dropout_rate: f64, layer_norm_epsilon: f64, initializer_factor: f64, - feed_forward_proj: HiddenAct, + feed_forward_proj: Activation, is_decoder: bool, is_encoder_decoder: bool, use_cache: bool, @@ -44,7 +42,7 @@ impl Default for Config { dropout_rate: 0.1, layer_norm_epsilon: 1e-6, initializer_factor: 1.0, - feed_forward_proj: HiddenAct::Relu, + feed_forward_proj: Activation::Relu, is_decoder: false, is_encoder_decoder: true, use_cache: true, @@ -63,7 +61,7 @@ impl Config { d_model: 768, dropout_rate: 0.1, eos_token_id: 1, - feed_forward_proj: HiddenAct::Relu, + feed_forward_proj: Activation::Relu, initializer_factor: 1.0, is_decoder: false, is_encoder_decoder: true, @@ -112,27 +110,23 @@ impl T5LayerNorm { struct T5DenseActDense { wi: Linear, wo: Linear, - dropout: Dropout, - act: HiddenAct, + act: Activation, } impl T5DenseActDense { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?; - let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?; - let dropout = Dropout::new(cfg.dropout_rate); + let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?; + let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; Ok(Self { wi, wo, - dropout, - act: HiddenAct::Relu, + act: Activation::Relu, }) } fn forward(&self, xs: &Tensor) -> Result<Tensor> { let xs = self.wi.forward(xs)?; let xs = self.act.forward(&xs)?; - let xs = self.dropout.forward(&xs)?; let xs = self.wo.forward(&xs)?; Ok(xs) } @@ -142,7 +136,6 @@ impl T5DenseActDense { struct T5LayerFF { dense_relu_dense: T5DenseActDense, layer_norm: T5LayerNorm, - dropout: Dropout, } impl T5LayerFF { @@ -151,18 +144,16 @@ impl T5LayerFF { let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?; let layer_norm = T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; - let dropout = Dropout::new(cfg.dropout_rate); Ok(Self { dense_relu_dense, layer_norm, - dropout, }) } fn forward(&self, xs: &Tensor) -> Result<Tensor> { let ys = self.layer_norm.forward(xs)?; let ys = self.dense_relu_dense.forward(&ys)?; - let xs = (xs + self.dropout.forward(&ys)?)?; + let xs = (xs + ys)?; Ok(xs) } } @@ -181,10 +172,10 @@ struct T5Attention { impl T5Attention { fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { let inner_dim = cfg.num_heads * cfg.d_kv; - let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?; - let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?; - let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?; - let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?; + let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?; + let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?; + let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?; + let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?; let relative_attention_bias = if h { let emb = embedding( cfg.relative_attention_num_buckets, @@ -235,7 +226,6 @@ impl T5Attention { struct T5LayerSelfAttention { self_attention: T5Attention, layer_norm: T5LayerNorm, - dropout: Dropout, } impl T5LayerSelfAttention { @@ -243,11 +233,9 @@ impl T5LayerSelfAttention { let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?; let layer_norm = T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; - let dropout = Dropout::new(cfg.dropout_rate); Ok(Self { self_attention, layer_norm, - dropout, }) } @@ -315,7 +303,6 @@ struct T5Stack { block: Vec<T5Block>, shared: Arc<Embedding>, final_layer_norm: T5LayerNorm, - dropout: Dropout, } impl T5Stack { @@ -328,12 +315,10 @@ impl T5Stack { cfg.layer_norm_epsilon, vb.pp("final_layer_norm"), )?; - let dropout = Dropout::new(cfg.dropout_rate); Ok(Self { block, shared: shared.clone(), final_layer_norm, - dropout, }) } @@ -341,12 +326,11 @@ impl T5Stack { let input_embeds = self.shared.as_ref().forward(input_ids)?; let (_b_sz, _seq_len) = input_embeds.dims2()?; - let mut hidden_states = self.dropout.forward(&input_embeds)?; + let mut hidden_states = input_embeds; for block in self.block.iter() { hidden_states = block.forward(&hidden_states)? } let hidden_states = self.final_layer_norm.forward(&hidden_states)?; - let hidden_states = self.dropout.forward(&hidden_states)?; Ok(hidden_states) } } |