| author | Juarez Bochi <juarez.bochi@grammarly.com> | 2023-09-13 10:27:20 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-13 19:27:20 +0200 |
| commit | 49d3f7f70814bd0e8b569f93bb76419306359251 (patch) | |
| tree | 9df579d302c2233ad3015996422e5b6f9b5a0436 /candle-transformers/src/models/t5.rs | |
| parent | 9a465e1b2601195b65f1422f16793d6825252231 (diff) | |
Add support to flan-t5 (#840)
Diffstat (limited to 'candle-transformers/src/models/t5.rs')
| -rw-r--r-- | candle-transformers/src/models/t5.rs | 54 |
1 file changed, 49 insertions, 5 deletions
```diff
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 325eb752..de7de496 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -149,26 +149,70 @@ impl T5DenseActDense {
 }
 
 #[derive(Debug)]
+struct T5DenseGatedActDense {
+    wi_0: Linear,
+    wi_1: Linear,
+    wo: Linear,
+    act: Activation,
+}
+
+impl T5DenseGatedActDense {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        Ok(Self {
+            wi_0,
+            wi_1,
+            wo,
+            act: Activation::NewGelu,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+        let hidden_linear = self.wi_1.forward(xs)?;
+        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+        let xs = self.wo.forward(&xs)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug)]
 struct T5LayerFF {
-    dense_relu_dense: T5DenseActDense,
+    dense_act: Option<T5DenseActDense>,
+    gated_dense_act: Option<T5DenseGatedActDense>,
     layer_norm: T5LayerNorm,
 }
 
 impl T5LayerFF {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        // is_gated_act is not supported.
-        let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+            (
+                None,
+                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+            )
+        } else {
+            (
+                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+                None,
+            )
+        };
         Ok(Self {
-            dense_relu_dense,
+            dense_act,
+            gated_dense_act,
             layer_norm,
         })
     }
 
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let ys = self.layer_norm.forward(xs)?;
-        let ys = self.dense_relu_dense.forward(&ys)?;
+        let ys = match &self.dense_act {
+            Some(dense_act) => dense_act.forward(&ys)?,
+            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+        };
         let xs = (xs + ys)?;
         Ok(xs)
     }
```
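For background on what the patch adds: flan-t5 checkpoints (like T5 v1.1) use a gated feed-forward block, so the layer computes `wo(gelu(wi_0(x)) * wi_1(x))` instead of the single-projection `wo(act(wi(x)))` of the original T5. The sketch below is a minimal reproduction of that math on plain Rust vectors, independent of candle; the `matvec` helper, the toy weights, and the tiny dimensions are made up purely for illustration.

```rust
// Minimal sketch of the gated feed-forward block that T5DenseGatedActDense
// implements: out = wo(gelu(wi_0(x)) * wi_1(x)).
// The matvec helper and toy weights below are illustrative, not candle API.

fn gelu(x: f32) -> f32 {
    // "New GELU" (tanh approximation), matching Activation::NewGelu.
    0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044_715 * x.powi(3))).tanh())
}

/// y = W * x, with W stored row-major (each row has x.len() entries).
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

fn gated_ff(x: &[f32], wi_0: &[Vec<f32>], wi_1: &[Vec<f32>], wo: &[Vec<f32>]) -> Vec<f32> {
    // Gated branch: element-wise product of the activated and the linear projection.
    let hidden_gelu: Vec<f32> = matvec(wi_0, x).into_iter().map(gelu).collect();
    let hidden_linear = matvec(wi_1, x);
    let gated: Vec<f32> = hidden_gelu
        .iter()
        .zip(&hidden_linear)
        .map(|(g, l)| g * l)
        .collect();
    // Project back from d_ff to d_model.
    matvec(wo, &gated)
}

fn main() {
    // Toy dimensions: d_model = 2, d_ff = 3.
    let x = vec![1.0_f32, -0.5];
    let wi_0 = vec![vec![0.1, 0.2], vec![-0.3, 0.4], vec![0.5, 0.6]];
    let wi_1 = vec![vec![0.2, -0.1], vec![0.0, 0.3], vec![0.7, 0.1]];
    let wo = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
    println!("{:?}", gated_ff(&x, &wi_0, &wi_1, &wo));
}
```

In the real model the same computation runs through candle's `Linear` layers and `broadcast_mul`, as in `T5DenseGatedActDense::forward` in the diff above, and `T5LayerFF::load` picks the gated or plain block based on the configured feed-forward activation.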