author    Laurent Mazare <laurent.mazare@gmail.com>  2024-04-22 18:52:00 +0200
committer GitHub <noreply@github.com>  2024-04-22 18:52:00 +0200
commit    b2e816752bb3b81ed5daaf4b623c3b5e6c0f7b67
tree      edf791a5d2a02bdcb28ae36d544e57e27e89772f /candle-nn
parent    618ecf5e231beb5bd0b1e59a171eb9cb0af95b01
Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.
* Use the fast variant by default.
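For illustration only (not part of the commit): a minimal usage sketch, assuming candle-nn's RmsNorm::new constructor and the Module trait from candle. Callers keep calling forward() as before; with this change, contiguous inputs take the fused kernel path while non-contiguous inputs fall back to the original implementation. The hidden size, eps value, and ones-initialized weight below are arbitrary.

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::RmsNorm;

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Arbitrary scale weights for a hidden size of 8.
    let weight = Tensor::ones(8, DType::F32, &device)?;
    let rms = RmsNorm::new(weight, 1e-5);

    // Contiguous input: forward() now dispatches to the fast fused kernel.
    let xs = Tensor::randn(0f32, 1f32, (2, 4, 8), &device)?;
    let ys = rms.forward(&xs)?;
    println!("{:?}", ys.shape());

    // A non-contiguous view (transposed square tensor) falls back to the
    // original LayerNorm-based path.
    let xs_sq = Tensor::randn(0f32, 1f32, (2, 8, 8), &device)?;
    let _ys_nc = rms.forward(&xs_sq.transpose(1, 2)?)?;
    Ok(())
}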
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/layer_norm.rs  17
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 26800a7b..23d0c01b 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -28,7 +28,7 @@
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
@@ -105,7 +105,7 @@ impl LayerNorm {
}
}
-impl crate::Module for LayerNorm {
+impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
@@ -162,11 +162,20 @@ impl RmsNorm {
pub fn into_inner(self) -> LayerNorm {
self.0
}
+
+ /// Faster variant of the forward kernel; it can only be used on contiguous tensors though.
+ pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
+ self.0.forward(xs)
+ }
}
-impl crate::Module for RmsNorm {
+impl Module for RmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- self.0.forward(xs)
+ if xs.is_contiguous() {
+ crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
+ } else {
+ self.0.forward(xs)
+ }
}
}
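For reference, a sketch of the computation both branches perform: y = x / sqrt(mean(x^2) + eps) * weight over the last dimension. It mirrors the existing LayerNorm slow path with mean removal disabled, not the fused crate::ops::rms_norm kernel itself; the function name is made up, and only standard candle tensor ops (sqr, sum_keepdim, broadcast_div/mul) are assumed.

use candle::{DType, Result, Tensor, D};

// Illustrative reference implementation of RMS normalization.
fn rms_norm_reference(xs: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let x_dtype = xs.dtype();
    // Accumulate in f32 for half-precision inputs, as the slow path does.
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = xs.dim(D::Minus1)?;
    let xs = xs.to_dtype(internal_dtype)?;
    // Mean of squares along the last dimension, kept for broadcasting.
    let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    // Normalize, cast back, and apply the learned scale.
    let xs_normed = xs.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
    xs_normed.to_dtype(x_dtype)?.broadcast_mul(weight)
}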