diff options
author | Jani Monoses <jani.monoses@gmail.com> | 2024-03-18 22:40:06 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-18 21:40:06 +0100 |
commit | 90fc82211f29282991afc8fea33c78169e674db1 (patch) | |
tree | 8b268039018f2e2287a4446076f8f9d0d589303a /candle-transformers/src/models/yi.rs | |
parent | 6a966cf9e0abee128f0b8b60f0063bfe5fdaff92 (diff) | |
download | candle-90fc82211f29282991afc8fea33c78169e674db1.tar.gz candle-90fc82211f29282991afc8fea33c78169e674db1.tar.bz2 candle-90fc82211f29282991afc8fea33c78169e674db1.zip |
Use a common with_tracing::RmsNorm in a few models. (#1871)
* Add RmsNorm with tracing.
* Use with_tracing::RmsNorm in some models.
Diffstat (limited to 'candle-transformers/src/models/yi.rs')
-rw-r--r-- | candle-transformers/src/models/yi.rs | 23 |
1 files changed, 1 insertions, 22 deletions
diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs index 14b6feeb..99d9de1b 100644 --- a/candle-transformers/src/models/yi.rs +++ b/candle-transformers/src/models/yi.rs @@ -1,5 +1,5 @@ /// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py -use crate::models::with_tracing::{linear_no_bias, Linear}; +use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; @@ -51,27 +51,6 @@ impl Config { } #[derive(Debug, Clone)] -struct RmsNorm { - inner: candle_nn::RmsNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let inner = candle_nn::rms_norm(size, eps, vb)?; - Ok(Self { inner, span }) - } -} - -impl Module for RmsNorm { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -#[derive(Debug, Clone)] struct RotaryEmbedding { sin: Tensor, cos: Tensor, |