Diffstat (limited to 'candle-transformers/src')
-rw-r--r-- | candle-transformers/src/models/gemma.rs | 19 |
1 file changed, 15 insertions, 4 deletions
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index ab2a9582..15e4dccb 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
 
 fn default_max_position_embeddings() -> usize {
     4096
@@ -11,8 +11,9 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
-    #[serde(alias = "hidden_activation")]
-    pub hidden_act: candle_nn::Activation,
+    // The code gemma configs include both hidden_act and hidden_activation.
+    pub hidden_act: Option<Activation>,
+    pub hidden_activation: Option<Activation>,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_attention_heads: usize,
@@ -26,6 +27,16 @@ pub struct Config {
     pub max_position_embeddings: usize,
 }
 
+impl Config {
+    fn hidden_act(&self) -> Result<Activation> {
+        match (self.hidden_act, self.hidden_activation) {
+            (None, Some(act)) | (Some(act), None) => Ok(act),
+            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
+            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct RmsNorm {
     weight: Tensor,
@@ -127,7 +138,7 @@ impl MLP {
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.hidden_act()?,
         })
     }
 }
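
The substance of the change is the config handling: gemma-style configs may carry the activation under either the hidden_act or the hidden_activation key, so the old #[serde(alias = ...)] approach is replaced by deserializing both keys as options and resolving them explicitly. Below is a self-contained sketch of that same pattern, compilable with only serde and serde_json as dependencies. It is an illustration, not candle's actual code: the Act enum is a hypothetical stand-in for candle_nn::Activation, plain String errors replace candle::bail!, and the JSON snippets are assumed shapes of gemma-style configs.

// Minimal sketch of the resolution logic from the diff above.
// Assumptions: serde + serde_json available; `Act` stands in for
// candle_nn::Activation; String errors stand in for candle::bail!.
use serde::Deserialize;

#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Act {
    Gelu,
    GeluPytorchTanh,
}

#[derive(Debug, Deserialize)]
struct Config {
    // Both keys are optional; a missing key deserializes to None.
    hidden_act: Option<Act>,
    hidden_activation: Option<Act>,
}

impl Config {
    // Mirrors Config::hidden_act in the diff: exactly one of the two
    // keys must be present, anything else is an error.
    fn hidden_act(&self) -> Result<Act, String> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            (Some(_), Some(_)) => Err("both hidden_act and hidden_activation are set".to_string()),
            (None, None) => Err("none of hidden_act and hidden_activation are set".to_string()),
        }
    }
}

fn main() {
    // A config carrying only hidden_activation resolves cleanly.
    let cfg: Config =
        serde_json::from_str(r#"{"hidden_activation": "gelu_pytorch_tanh"}"#).unwrap();
    assert_eq!(cfg.hidden_act().unwrap(), Act::GeluPytorchTanh);

    // A config with both keys is rejected as ambiguous.
    let cfg: Config = serde_json::from_str(
        r#"{"hidden_act": "gelu", "hidden_activation": "gelu_pytorch_tanh"}"#,
    )
    .unwrap();
    assert!(cfg.hidden_act().is_err());

    // A config with neither key also fails at load time.
    let cfg: Config = serde_json::from_str("{}").unwrap();
    assert!(cfg.hidden_act().is_err());
}

One plausible reading of the design choice: erroring on the ambiguous and empty cases at load time, rather than picking a default activation, means a misconfigured model fails immediately instead of silently producing degraded output, which would be much harder to trace back to the config.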