summary | refs | log | tree | commit | diff
path: root/candle-nn
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-02-28 21:02:41 +0100
committer: GitHub <noreply@github.com> 2024-02-28 21:02:41 +0100
commit: 4fd00b890036ef67391a9cc03f896247d0a75711 (patch)
tree: 19121cce4cab5406e9fda202de71a32427c096bd /candle-nn
parent: 57267cd53612ede04090853680125b17956804f3 (diff)
download: candle-4fd00b890036ef67391a9cc03f896247d0a75711.tar.gz
candle-4fd00b890036ef67391a9cc03f896247d0a75711.tar.bz2
candle-4fd00b890036ef67391a9cc03f896247d0a75711.zip
Add the StarCoder2 model. (#1779)
* Add the StarCoder2 model.
* Add the example code and get things to work.
* And also tweak the readme.
Diffstat (limited to 'candle-nn')
-rw-r--r-- candle-nn/src/activation.rs | 4
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 60a7a6d1..b9745375 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -5,6 +5,7 @@ use serde::Deserialize;
#[serde(rename_all = "lowercase")]
pub enum Activation {
#[default]
+ #[serde(alias = "gelu")]
Gelu,
#[serde(alias = "gelu_new")]
NewGelu,
@@ -19,6 +20,8 @@ pub enum Activation {
HardSwish,
Elu(f64),
LeakyRelu(f64),
+ #[serde(alias = "gelu_pytorch_tanh")]
+ GeluPytorchTanh,
}
impl super::Module for Activation {
@@ -38,6 +41,7 @@ impl super::Module for Activation {
Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
+ Self::GeluPytorchTanh => xs.gelu(),
}
}
}