author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-10 21:19:21 +0200
---|---|---
committer | GitHub <noreply@github.com> | 2024-04-10 21:19:21 +0200
commit | a0460cd2b13a396ff8545dc1bbffa741f2ec3d79 | (patch)
tree | 9910e17010354b5edbcae1269451eccae300da89 | /candle-transformers/src
parent | b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9 | (diff)
Add the code-gemma models. (#2038)
* Add the code-gemma models.
* Tweak to the gemma config.
Diffstat (limited to 'candle-transformers/src')
-rw-r--r-- | candle-transformers/src/models/gemma.rs | 19 |
1 file changed, 15 insertions(+), 4 deletions(-)
```diff
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index ab2a9582..15e4dccb 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -1,7 +1,7 @@
 use std::sync::Arc;

 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};

 fn default_max_position_embeddings() -> usize {
     4096
@@ -11,8 +11,9 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
-    #[serde(alias = "hidden_activation")]
-    pub hidden_act: candle_nn::Activation,
+    // The code gemma configs include both hidden_act and hidden_activation.
+    pub hidden_act: Option<Activation>,
+    pub hidden_activation: Option<Activation>,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_attention_heads: usize,
@@ -26,6 +27,16 @@ pub struct Config {
     pub max_position_embeddings: usize,
 }

+impl Config {
+    fn hidden_act(&self) -> Result<Activation> {
+        match (self.hidden_act, self.hidden_activation) {
+            (None, Some(act)) | (Some(act), None) => Ok(act),
+            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
+            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct RmsNorm {
     weight: Tensor,
@@ -127,7 +138,7 @@ impl MLP {
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.hidden_act()?,
         })
     }
 }
```
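To make the new config handling concrete, here is a minimal, self-contained sketch of the same resolution pattern (not part of the patch; it assumes `serde` with the `derive` feature and `serde_json` as dependencies, stands in `String` for `candle_nn::Activation` and a plain `Result<_, String>` for candle's error type, and the activation names are only illustrative of gemma- vs. code-gemma-style config files):

```rust
use serde::Deserialize;

// Stand-in for the relevant fields of the gemma `Config`; the real struct
// uses `candle_nn::Activation` rather than `String`. `Option` fields are
// optional by default under serde, so either key may be absent in the JSON.
#[derive(Debug, Deserialize)]
struct Config {
    hidden_act: Option<String>,
    hidden_activation: Option<String>,
}

impl Config {
    // Mirrors the new helper from the patch: exactly one of the two config
    // keys must be present, and both remaining cases are reported as errors.
    fn hidden_act(&self) -> Result<&str, String> {
        match (&self.hidden_act, &self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act.as_str()),
            (Some(_), Some(_)) => Err("both hidden_act and hidden_activation are set".into()),
            (None, None) => Err("none of hidden_act and hidden_activation are set".into()),
        }
    }
}

fn main() {
    // An older gemma-style config only sets `hidden_act`...
    let gemma: Config = serde_json::from_str(r#"{ "hidden_act": "gelu" }"#).unwrap();
    assert_eq!(gemma.hidden_act().unwrap(), "gelu");

    // ...while a code-gemma-style config sets `hidden_activation` instead.
    let code_gemma: Config =
        serde_json::from_str(r#"{ "hidden_activation": "gelu_pytorch_tanh" }"#).unwrap();
    assert_eq!(code_gemma.hidden_act().unwrap(), "gelu_pytorch_tanh");
}
```

Matching on the pair of `Option`s keeps both failure modes explicit (both keys set, or neither) instead of silently preferring one key, which is the same trade-off the patch makes with `candle::bail!`.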