author: Laurent Mazare <laurent.mazare@gmail.com> (2023-10-27 21:51:16 +0200)
committer: GitHub <noreply@github.com> (2023-10-27 20:51:16 +0100)
commit: c8face3f95a9c57b4714cd95dc69237533558c25 (patch)
tree: fc2f6b0aa8c2f71793f71293fdbd14b10dcc9575 /candle-transformers/src/models/persimmon.rs
parent: 85bea43e5b088b94612b0fd7ed8f09261dc79d52 (diff)
Add the relu2 and relu6 activations. (#1201)
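For context, relu2 is the squared ReLU and relu6 is ReLU capped at 6. A minimal sketch of the usual definitions, assuming candle's Tensor methods relu, sqr, and clamp; this is not taken from the commit itself:

use candle::{Result, Tensor};

// relu2(x) = relu(x)^2: square the rectified input.
fn relu2(xs: &Tensor) -> Result<Tensor> {
    xs.relu()?.sqr()
}

// relu6(x) = min(max(x, 0), 6): clamp the input to the [0, 6] range.
fn relu6(xs: &Tensor) -> Result<Tensor> {
    xs.clamp(0f32, 6f32)
}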
Diffstat (limited to 'candle-transformers/src/models/persimmon.rs')
-rw-r--r--  candle-transformers/src/models/persimmon.rs  56
1 file changed, 56 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/persimmon.rs b/candle-transformers/src/models/persimmon.rs
new file mode 100644
index 00000000..afee7c83
--- /dev/null
+++ b/candle-transformers/src/models/persimmon.rs
@@ -0,0 +1,56 @@
+use candle::DType;
+use serde::Deserialize;
+
+pub const DTYPE: DType = DType::F32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum PositionEmbeddingType {
+ Absolute,
+ Alibi,
+}
+
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ pub vocab_size: usize,
+ pub hidden_size: usize,
+ pub intermediate_size: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub num_key_value_heads: usize,
+ pub hidden_act: candle_nn::Activation,
+ pub max_position_embeddings: usize,
+ pub initializer_range: f64,
+ pub layer_norm_eps: f64,
+ pub rms_norm_eps: f64,
+ pub use_cache: bool,
+ pub tie_word_embeddings: bool,
+ pub rope_theta: f64,
+ pub qk_layernorm: bool,
+ pub partial_rotary_factor: f64,
+}
+
+impl Config {
+ pub fn base_8b() -> Self {
+ // https://huggingface.co/adept/persimmon-8b-base/blob/main/config.json
+ Self {
+ hidden_act: candle_nn::Activation::Relu,
+ hidden_size: 4096,
+ initializer_range: 0.02,
+ intermediate_size: 16384,
+ layer_norm_eps: 1e-05,
+ max_position_embeddings: 16384,
+ num_attention_heads: 64,
+ num_hidden_layers: 36,
+ num_key_value_heads: 64,
+ qk_layernorm: true,
+ rms_norm_eps: 1e-06,
+ rope_theta: 25000.0,
+ tie_word_embeddings: false,
+ use_cache: true,
+ vocab_size: 262144,
+ partial_rotary_factor: 0.5,
+ }
+ }
+}
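Since Config derives Deserialize, it can be read straight from a Hugging Face config.json. A hedged usage sketch, assuming serde_json is available and a config.json in the persimmon-8b-base layout sits in the working directory (the file name is illustrative):

use candle_transformers::models::persimmon::Config;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Parse a Hugging Face style config.json into the typed Config.
    let raw = std::fs::read_to_string("config.json")?;
    let config: Config = serde_json::from_str(&raw)?;
    println!("hidden layers: {}", config.num_hidden_layers);

    // Alternatively, use the built-in preset for adept/persimmon-8b-base.
    let preset = Config::base_8b();
    println!("preset vocab size: {}", preset.vocab_size);
    Ok(())
}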