author    Laurent Mazare <laurent.mazare@gmail.com>  2023-09-26 22:30:21 +0100
committer GitHub <noreply@github.com>                2023-09-26 22:30:21 +0100
commit    ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99 (patch)
tree      0010e960666b6554328b31697d764eddded45461
parent    4abc1ea34dbc834e561f442737faf2c735f0a6ce (diff)
Use the gelu-erf activation. (#969)
 candle-core/src/quantized/utils.rs     | 6 +++---
 candle-nn/src/activation.rs            | 4 +---
 candle-transformers/src/models/bert.rs | 4 +---
 3 files changed, 5 insertions(+), 9 deletions(-)
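
This commit switches the default Gelu activation from candle's gelu op, which (as the removed TODO comments note) implements the tanh-based "gelu_new" approximation, to gelu_erf, the erf-based definition that the reference transformers implementation uses for "gelu". A minimal standalone sketch of the two formulas, assuming the libm crate for erf (an illustration of the numerical difference, not candle's kernel code):

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + libm::erf(x / std::f64::consts::SQRT_2))
}

// Tanh approximation ("gelu_new"): 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    for &x in &[-3.0f64, -1.0, 0.0, 0.5, 1.0, 3.0] {
        println!(
            "x = {x:+.1}  erf = {:+.6}  tanh = {:+.6}  diff = {:+.2e}",
            gelu_erf(x),
            gelu_tanh(x),
            gelu_erf(x) - gelu_tanh(x)
        );
    }
}

The difference is small but was enough to show up when comparing BERT-style models such as all-MiniLM-L6-v2 against the reference implementation, which is what the removed comments in activation.rs and bert.rs referred to.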
diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs
index edbffa35..fa6eff51 100644
--- a/candle-core/src/quantized/utils.rs
+++ b/candle-core/src/quantized/utils.rs
@@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
let expected_blocks = xs.len() / block_size;
let actual_blocks = ys.len();
- //validate that the input is the right size
+ // Validate that the input is the right size
if expected_blocks != actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}
@@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
let actual_output_len = ys.len();
let expected_output_len = xs.len() * block_size;
- //validate that the output is the right size
+ // Validate that the output is the right size
if expected_output_len != actual_output_len {
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
}
- //zip the blocks and outputs together
+ // Zip the blocks and outputs together
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}
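
The context around these hunks shows the pattern the quantization helpers rely on: validate that the flat buffer length matches the number of blocks times the block size, then pair each block with its output chunk via zip and chunks_exact_mut. A simplified sketch of the same idea outside candle's types (Block and BLOCK_SIZE are illustrative stand-ins for the GgmlType block structs):

const BLOCK_SIZE: usize = 32;

// Illustrative stand-in for a quantized block (not candle's actual layout).
#[allow(dead_code)]
struct Block {
    scale: f32,
    data: [i8; BLOCK_SIZE],
}

fn group_blocks<'a>(
    xs: &'a [Block],
    ys: &'a mut [f32],
) -> Result<Vec<(&'a Block, &'a mut [f32])>, String> {
    // Validate that the output is the right size.
    let expected_output_len = xs.len() * BLOCK_SIZE;
    if ys.len() != expected_output_len {
        return Err(format!(
            "expected an output of length {expected_output_len}, got {}",
            ys.len()
        ));
    }
    // Zip the blocks and outputs together.
    Ok(xs.iter().zip(ys.chunks_exact_mut(BLOCK_SIZE)).collect())
}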
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 17467b31..1e67ed53 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -16,9 +16,7 @@ pub enum Activation {
impl super::Module for Activation {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
- Self::Gelu => xs.gelu(),
- // TODO: This is "gelu_new", not the original "gelu".
- // There's some small numerical difference:
+ Self::Gelu => xs.gelu_erf(),
// https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
Self::NewGelu => xs.gelu(),
Self::Relu => xs.relu(),
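
After this hunk, Activation::Gelu dispatches to the erf-based op while Activation::NewGelu keeps the tanh approximation. A hedged usage sketch at the tensor level (Tensor::new, Device::Cpu, and to_vec1 are assumed from candle_core's usual API; gelu and gelu_erf are the methods this diff calls):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::new(&[-3.0f32, -1.0, 0.0, 0.5, 1.0, 3.0], &Device::Cpu)?;
    // Erf-based GELU, now used for Activation::Gelu.
    let exact = xs.gelu_erf()?;
    // Tanh approximation, still used for Activation::NewGelu.
    let approx = xs.gelu()?;
    println!("gelu_erf: {:?}", exact.to_vec1::<f32>()?);
    println!("gelu    : {:?}", approx.to_vec1::<f32>()?);
    Ok(())
}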
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 3f164a3a..8af34465 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -25,10 +25,8 @@ impl HiddenActLayer {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
match self.act {
- // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
- // small numerical difference.
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
- HiddenAct::Gelu => xs.gelu(),
+ HiddenAct::Gelu => xs.gelu_erf(),
HiddenAct::Relu => xs.relu(),
}
}