summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/t5_model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/musicgen/t5_model.rs')
-rw-r--r--candle-examples/examples/musicgen/t5_model.rs46
1 files changed, 15 insertions, 31 deletions
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 613b4112..33b11b95 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -1,10 +1,8 @@
// T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
-use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
-use anyhow::Result;
-use candle::{DType, Tensor, D};
-use candle_nn::Module;
+use candle::{DType, Result, Tensor, D};
+use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
@@ -21,7 +19,7 @@ pub struct Config {
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
- feed_forward_proj: HiddenAct,
+ feed_forward_proj: Activation,
is_decoder: bool,
is_encoder_decoder: bool,
use_cache: bool,
@@ -44,7 +42,7 @@ impl Default for Config {
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
- feed_forward_proj: HiddenAct::Relu,
+ feed_forward_proj: Activation::Relu,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
@@ -63,7 +61,7 @@ impl Config {
d_model: 768,
dropout_rate: 0.1,
eos_token_id: 1,
- feed_forward_proj: HiddenAct::Relu,
+ feed_forward_proj: Activation::Relu,
initializer_factor: 1.0,
is_decoder: false,
is_encoder_decoder: true,
@@ -112,27 +110,23 @@ impl T5LayerNorm {
struct T5DenseActDense {
wi: Linear,
wo: Linear,
- dropout: Dropout,
- act: HiddenAct,
+ act: Activation,
}
impl T5DenseActDense {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?;
- let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?;
- let dropout = Dropout::new(cfg.dropout_rate);
+ let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+ let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
Ok(Self {
wi,
wo,
- dropout,
- act: HiddenAct::Relu,
+ act: Activation::Relu,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.wi.forward(xs)?;
let xs = self.act.forward(&xs)?;
- let xs = self.dropout.forward(&xs)?;
let xs = self.wo.forward(&xs)?;
Ok(xs)
}
@@ -142,7 +136,6 @@ impl T5DenseActDense {
struct T5LayerFF {
dense_relu_dense: T5DenseActDense,
layer_norm: T5LayerNorm,
- dropout: Dropout,
}
impl T5LayerFF {
@@ -151,18 +144,16 @@ impl T5LayerFF {
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
- let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
dense_relu_dense,
layer_norm,
- dropout,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.layer_norm.forward(xs)?;
let ys = self.dense_relu_dense.forward(&ys)?;
- let xs = (xs + self.dropout.forward(&ys)?)?;
+ let xs = (xs + ys)?;
Ok(xs)
}
}
@@ -181,10 +172,10 @@ struct T5Attention {
impl T5Attention {
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
- let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?;
- let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?;
- let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?;
- let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?;
+ let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
+ let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
+ let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
+ let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
let relative_attention_bias = if h {
let emb = embedding(
cfg.relative_attention_num_buckets,
@@ -235,7 +226,6 @@ impl T5Attention {
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
- dropout: Dropout,
}
impl T5LayerSelfAttention {
@@ -243,11 +233,9 @@ impl T5LayerSelfAttention {
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
- let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
self_attention,
layer_norm,
- dropout,
})
}
@@ -315,7 +303,6 @@ struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
final_layer_norm: T5LayerNorm,
- dropout: Dropout,
}
impl T5Stack {
@@ -328,12 +315,10 @@ impl T5Stack {
cfg.layer_norm_epsilon,
vb.pp("final_layer_norm"),
)?;
- let dropout = Dropout::new(cfg.dropout_rate);
Ok(Self {
block,
shared: shared.clone(),
final_layer_norm,
- dropout,
})
}
@@ -341,12 +326,11 @@ impl T5Stack {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let (_b_sz, _seq_len) = input_embeds.dims2()?;
- let mut hidden_states = self.dropout.forward(&input_embeds)?;
+ let mut hidden_states = input_embeds;
for block in self.block.iter() {
hidden_states = block.forward(&hidden_states)?
}
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
- let hidden_states = self.dropout.forward(&hidden_states)?;
Ok(hidden_states)
}
}