Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/gemma.rs | 19
1 file changed, 15 insertions, 4 deletions
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index ab2a9582..15e4dccb 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
 
 fn default_max_position_embeddings() -> usize {
     4096
@@ -11,8 +11,9 @@ fn default_max_position_embeddings() -> usize {
 pub struct Config {
     pub attention_bias: bool,
     pub head_dim: usize,
-    #[serde(alias = "hidden_activation")]
-    pub hidden_act: candle_nn::Activation,
+    // The code gemma configs include both hidden_act and hidden_activation.
+    pub hidden_act: Option<Activation>,
+    pub hidden_activation: Option<Activation>,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_attention_heads: usize,
@@ -26,6 +27,16 @@ pub struct Config {
     pub max_position_embeddings: usize,
 }
 
+impl Config {
+    fn hidden_act(&self) -> Result<Activation> {
+        match (self.hidden_act, self.hidden_activation) {
+            (None, Some(act)) | (Some(act), None) => Ok(act),
+            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
+            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 struct RmsNorm {
     weight: Tensor,
@@ -127,7 +138,7 @@ impl MLP {
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.hidden_act()?,
         })
     }
 }
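
For context, the patch moves the choice of activation from deserialization time (the old #[serde(alias = ...)]) to model-build time: both fields are now optional and Config::hidden_act() accepts exactly one of them. The following standalone sketch (not part of the commit) mirrors that resolution logic; Activation is cut down to a hypothetical two-variant enum and candle::bail! is replaced by a plain Result<_, String> so the snippet compiles without the candle crates.

// Standalone sketch of the fallback logic in the patch above.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Activation {
    Gelu,
    Silu,
}

#[derive(Debug, Default)]
struct Config {
    hidden_act: Option<Activation>,
    hidden_activation: Option<Activation>,
}

impl Config {
    // Mirrors Config::hidden_act from the patch: exactly one of the two
    // keys must be present; both or neither is treated as a config error.
    fn hidden_act(&self) -> Result<Activation, String> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            (Some(_), Some(_)) => Err("both hidden_act and hidden_activation are set".into()),
            (None, None) => Err("none of hidden_act and hidden_activation are set".into()),
        }
    }
}

fn main() {
    // Plain gemma configs set hidden_act only.
    let gemma = Config { hidden_act: Some(Activation::Gelu), ..Default::default() };
    assert_eq!(gemma.hidden_act().unwrap(), Activation::Gelu);

    // Code gemma configs may use the hidden_activation key instead.
    let code_gemma = Config { hidden_activation: Some(Activation::Gelu), ..Default::default() };
    assert_eq!(code_gemma.hidden_act().unwrap(), Activation::Gelu);

    // Conflicting configs fail loudly instead of silently picking one key.
    let ambiguous = Config {
        hidden_act: Some(Activation::Gelu),
        hidden_activation: Some(Activation::Silu),
    };
    assert!(ambiguous.hidden_act().is_err());
}

Because both fields are Options, serde fills a missing key with None, so configs carrying either spelling deserialize cleanly; a config that sets both keys (or neither) is only rejected once MLP construction calls cfg.hidden_act()?.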