diff options
author: Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-22 18:52:00 +0200
committer: GitHub <noreply@github.com> | 2024-04-22 18:52:00 +0200
commit: b2e816752bb3b81ed5daaf4b623c3b5e6c0f7b67 (patch)
tree: edf791a5d2a02bdcb28ae36d544e57e27e89772f /candle-nn
parent: 618ecf5e231beb5bd0b1e59a171eb9cb0af95b01 (diff)
download: candle-b2e816752bb3b81ed5daaf4b623c3b5e6c0f7b67.tar.gz, candle-b2e816752bb3b81ed5daaf4b623c3b5e6c0f7b67.tar.bz2, candle-b2e816752bb3b81ed5daaf4b623c3b5e6c0f7b67.zip
Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.
* Use the fast variant by default.
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/layer_norm.rs | 17 +++++++++++++----
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 26800a7b..23d0c01b 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -28,7 +28,7 @@
 //! ```
 //!
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct LayerNormConfig {
@@ -105,7 +105,7 @@ impl LayerNorm {
     }
 }
 
-impl crate::Module for LayerNorm {
+impl Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -162,11 +162,20 @@ impl RmsNorm {
     pub fn into_inner(self) -> LayerNorm {
         self.0
     }
+
+    /// Faster variant of the forward kernel, this can only be used on contiguous tensors though.
+    pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
+        self.0.forward(xs)
+    }
 }
 
-impl crate::Module for RmsNorm {
+impl Module for RmsNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.0.forward(xs)
+        if xs.is_contiguous() {
+            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
+        } else {
+            self.0.forward(xs)
+        }
     }
 }
 