summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/yi.rs
diff options
context:
space:
mode:
authorJani Monoses <jani.monoses@gmail.com>2024-03-18 22:40:06 +0200
committerGitHub <noreply@github.com>2024-03-18 21:40:06 +0100
commit90fc82211f29282991afc8fea33c78169e674db1 (patch)
tree8b268039018f2e2287a4446076f8f9d0d589303a /candle-transformers/src/models/yi.rs
parent6a966cf9e0abee128f0b8b60f0063bfe5fdaff92 (diff)
downloadcandle-90fc82211f29282991afc8fea33c78169e674db1.tar.gz
candle-90fc82211f29282991afc8fea33c78169e674db1.tar.bz2
candle-90fc82211f29282991afc8fea33c78169e674db1.zip
Use a common with_tracing::RmsNorm in a few models. (#1871)
* Add RmsNorm with tracing. * Use with_tracing::RmsNorm in some models.
Diffstat (limited to 'candle-transformers/src/models/yi.rs')
-rw-r--r--candle-transformers/src/models/yi.rs23
1 files changed, 1 insertions, 22 deletions
diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs
index 14b6feeb..99d9de1b 100644
--- a/candle-transformers/src/models/yi.rs
+++ b/candle-transformers/src/models/yi.rs
@@ -1,5 +1,5 @@
/// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
@@ -51,27 +51,6 @@ impl Config {
}
#[derive(Debug, Clone)]
-struct RmsNorm {
- inner: candle_nn::RmsNorm,
- span: tracing::Span,
-}
-
-impl RmsNorm {
- fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
- let inner = candle_nn::rms_norm(size, eps, vb)?;
- Ok(Self { inner, span })
- }
-}
-
-impl Module for RmsNorm {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
-#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,