author     Juarez Bochi <juarez.bochi@grammarly.com>    2023-09-13 10:27:20 -0700
committer  GitHub <noreply@github.com>                  2023-09-13 19:27:20 +0200
commit     49d3f7f70814bd0e8b569f93bb76419306359251 (patch)
tree       9df579d302c2233ad3015996422e5b6f9b5a0436 /candle-transformers/src/models/t5.rs
parent     9a465e1b2601195b65f1422f16793d6825252231 (diff)
Add support for flan-t5 (#840)
Diffstat (limited to 'candle-transformers/src/models/t5.rs')
-rw-r--r--  candle-transformers/src/models/t5.rs  54
1 file changed, 49 insertions(+), 5 deletions(-)
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 325eb752..de7de496 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -149,26 +149,70 @@ impl T5DenseActDense {
}
#[derive(Debug)]
+struct T5DenseGatedActDense {
+ wi_0: Linear,
+ wi_1: Linear,
+ wo: Linear,
+ act: Activation,
+}
+
+impl T5DenseGatedActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+ let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+ let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi_0,
+ wi_1,
+ wo,
+ act: Activation::NewGelu,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+ let hidden_linear = self.wi_1.forward(xs)?;
+ let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
struct T5LayerFF {
- dense_relu_dense: T5DenseActDense,
+ dense_act: Option<T5DenseActDense>,
+ gated_dense_act: Option<T5DenseGatedActDense>,
layer_norm: T5LayerNorm,
}
impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- // is_gated_act is not supported.
- let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+ (
+ None,
+ Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ )
+ } else {
+ (
+ Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ None,
+ )
+ };
Ok(Self {
- dense_relu_dense,
+ dense_act,
+ gated_dense_act,
layer_norm,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.layer_norm.forward(xs)?;
- let ys = self.dense_relu_dense.forward(&ys)?;
+ let ys = match &self.dense_act {
+ Some(dense_act) => dense_act.forward(&ys)?,
+ None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+ };
let xs = (xs + ys)?;
Ok(xs)
}
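
The new T5DenseGatedActDense introduced by this commit implements the gated (GEGLU-style) feed-forward used by flan-t5: the input is projected twice, one branch goes through NewGELU, the two branches are multiplied element-wise, and the result is projected back to d_model. Below is a rough, self-contained sketch of that forward pass; the helper name gated_ff, the toy dimensions, and the direct use of candle_core/candle_nn are assumptions for illustration, not code from the repository.

// Minimal sketch (not part of the commit): the flan-t5 gated feed-forward,
// out = wo(NewGELU(wi_0(x)) * wi_1(x)), written against candle_core/candle_nn.
use candle_core::{Device, Result, Tensor};
use candle_nn::{Activation, Linear, Module};

fn gated_ff(xs: &Tensor, wi_0: &Linear, wi_1: &Linear, wo: &Linear) -> Result<Tensor> {
    // Gate branch: project to d_ff and apply the NewGELU non-linearity.
    let hidden_gelu = Activation::NewGelu.forward(&wi_0.forward(xs)?)?;
    // Linear branch: same projection, no activation.
    let hidden_linear = wi_1.forward(xs)?;
    // Element-wise gating, then project back down to d_model.
    wo.forward(&hidden_gelu.broadcast_mul(&hidden_linear)?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (d_model, d_ff) = (4usize, 8usize); // toy sizes; real flan-t5 configs are much larger
    let wi_0 = Linear::new(Tensor::randn(0f32, 1f32, (d_ff, d_model), &dev)?, None);
    let wi_1 = Linear::new(Tensor::randn(0f32, 1f32, (d_ff, d_model), &dev)?, None);
    let wo = Linear::new(Tensor::randn(0f32, 1f32, (d_model, d_ff), &dev)?, None);
    let xs = Tensor::randn(0f32, 1f32, (1, 3, d_model), &dev)?; // (batch, seq, d_model)
    let ys = gated_ff(&xs, &wi_0, &wi_1, &wo)?;
    println!("{:?}", ys.dims()); // [1, 3, 4]
    Ok(())
}

In T5LayerFF::load above, this gated path is selected when cfg.feed_forward_proj is Activation::NewGelu (the flan-t5 configuration); otherwise the original, ungated T5DenseActDense path is used.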