author    Laurent Mazare <laurent.mazare@gmail.com> 2024-04-10 21:19:21 +0200
committer GitHub <noreply@github.com> 2024-04-10 21:19:21 +0200
commit    a0460cd2b13a396ff8545dc1bbffa741f2ec3d79 (patch)
tree      9910e17010354b5edbcae1269451eccae300da89
parent    b81ecf712d1854598d6c9f9cfa06fbf0093f3bc9 (diff)
Add the code-gemma models. (#2038)
* Add the code-gemma models.

* Tweak to the gemma config.
-rw-r--r--  candle-examples/examples/gemma/main.rs    12
-rw-r--r--  candle-transformers/src/models/gemma.rs   19
2 files changed, 27 insertions(+), 4 deletions(-)
diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs
index 0e37f5cd..a5f7d591 100644
--- a/candle-examples/examples/gemma/main.rs
+++ b/candle-examples/examples/gemma/main.rs
@@ -30,6 +30,14 @@ enum Which {
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
+ #[value(name = "code-2b")]
+ CodeBase2B,
+ #[value(name = "code-7b")]
+ CodeBase7B,
+ #[value(name = "code-2b-it")]
+ CodeInstruct2B,
+ #[value(name = "code-7b-it")]
+ CodeInstruct7B,
}
struct TextGeneration {
@@ -224,6 +232,10 @@ fn main() -> Result<()> {
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
+ Which::CodeBase2B => "google/codegemma-2b".to_string(),
+ Which::CodeBase7B => "google/codegemma-7b".to_string(),
+ Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
+ Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
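The four new variants hook into the example's existing clap ValueEnum derive, so the user-facing names ("code-2b", "code-7b-it", ...) resolve to enum variants and from there to the Hugging Face repo ids above. Below is a minimal sketch of that resolution, assuming clap 4 with the derive feature; this trimmed-down Which is a stand-in for the example's enum, reproducing only two of the new variants:

use clap::ValueEnum;

// Stand-in for the example's Which enum, keeping just two of the
// code-gemma variants added by this patch.
#[derive(Debug, Clone, Copy, PartialEq, ValueEnum)]
enum Which {
    #[value(name = "code-2b")]
    CodeBase2B,
    #[value(name = "code-7b-it")]
    CodeInstruct7B,
}

fn main() {
    // clap maps the user-facing value name back to the enum variant.
    let which = Which::from_str("code-2b", /* ignore_case: */ true).unwrap();
    assert_eq!(which, Which::CodeBase2B);
    println!("{which:?}");
}

Invoking the example with something like --which code-2b-it (the exact flag name is not shown in this hunk) would then select google/codegemma-2b-it.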
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index ab2a9582..15e4dccb 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -1,7 +1,7 @@
use std::sync::Arc;
use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{linear_b as linear, Linear, VarBuilder};
+use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
fn default_max_position_embeddings() -> usize {
4096
@@ -11,8 +11,9 @@ fn default_max_position_embeddings() -> usize {
pub struct Config {
pub attention_bias: bool,
pub head_dim: usize,
- #[serde(alias = "hidden_activation")]
- pub hidden_act: candle_nn::Activation,
+ // The code gemma configs include both hidden_act and hidden_activation.
+ pub hidden_act: Option<Activation>,
+ pub hidden_activation: Option<Activation>,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_attention_heads: usize,
@@ -26,6 +27,16 @@ pub struct Config {
pub max_position_embeddings: usize,
}
+impl Config {
+ fn hidden_act(&self) -> Result<Activation> {
+ match (self.hidden_act, self.hidden_activation) {
+ (None, Some(act)) | (Some(act), None) => Ok(act),
+ (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
(None, None) => candle::bail!("neither hidden_act nor hidden_activation is set"),
+ }
+ }
+}
+
#[derive(Debug, Clone)]
struct RmsNorm {
weight: Tensor,
@@ -127,7 +138,7 @@ impl MLP {
gate_proj,
up_proj,
down_proj,
- act_fn: cfg.hidden_act,
+ act_fn: cfg.hidden_act()?,
})
}
}
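The config tweak is the subtler half of the patch: per the comment above, the code-gemma config files spell the activation field hidden_activation where the earlier gemma configs used hidden_act, so both names become optional and a single accessor picks whichever one is present. Here is a self-contained sketch of that fallback, assuming serde and serde_json as dependencies; MiniConfig and Act are stand-ins for the real Config and candle_nn::Activation:

use serde::Deserialize;

// Stand-in for candle_nn::Activation, with only the variants needed here.
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Act {
    Gelu,
    GeluPytorchTanh,
}

// Stand-in for the patched Config: both field names are optional.
#[derive(Debug, Deserialize)]
struct MiniConfig {
    hidden_act: Option<Act>,
    hidden_activation: Option<Act>,
}

impl MiniConfig {
    // Same shape as the accessor added in the patch, with String errors
    // standing in for candle::bail!.
    fn hidden_act(&self) -> Result<Act, String> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            (Some(_), Some(_)) => Err("both hidden_act and hidden_activation are set".into()),
            (None, None) => Err("neither hidden_act nor hidden_activation is set".into()),
        }
    }
}

fn main() {
    // An old-style gemma config and a code-gemma-style one both resolve
    // through the same accessor.
    for json in [
        r#"{ "hidden_act": "gelu" }"#,
        r#"{ "hidden_activation": "gelu_pytorch_tanh" }"#,
    ] {
        let cfg: MiniConfig = serde_json::from_str(json).unwrap();
        println!("{json} -> {:?}", cfg.hidden_act());
    }
}

Keeping two Option fields instead of the previous #[serde(alias = "hidden_activation")] lets the accessor enforce that exactly one of the two names is present and report an explicit error otherwise, which a serde alias cannot express.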