Diffstat (limited to 'candle-examples/examples/llama/model.rs')
-rw-r--r-- | candle-examples/examples/llama/model.rs | 23 |
1 file changed, 4 insertions(+), 19 deletions(-)
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 751b5902..e0bb70e7 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -152,35 +152,20 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 }
 
 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }
 
 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps, span })
+        let inner = candle_nn::rms_norm(size, eps, vb)?;
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let in_dtype = x.dtype();
-        // This is a no-op if x's dtype is already f32.
-        let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        let x = x.to_dtype(in_dtype)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }
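The deleted forward body was a hand-written RMS normalization: up-cast to f32, compute y = x / sqrt(mean(x^2) + eps), multiply by the learned scale, then cast back to the input dtype. For reference, here is a minimal dependency-free sketch of that math over a single hidden-dimension row; rms_norm_row is a hypothetical helper (not part of candle's API), and the f32 up/down-cast the deleted code performed for f16/bf16 inputs is omitted.

// Sketch of the RMS-norm math the deleted `forward` implemented by hand:
//   y = x / sqrt(mean(x^2) + eps) * scale
// `rms_norm_row` is an illustrative name, not candle API, and the f32
// up-cast/down-cast the original did for non-f32 dtypes is left out.
fn rms_norm_row(x: &[f32], scale: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), scale.len());
    // mean(x^2) over the hidden dimension, matching `sum_keepdim(2) / hidden_size`.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(scale).map(|(v, s)| v * inv_rms * s).collect()
}

fn main() {
    // For x = [1, 2, 3, 4], mean(x^2) = 7.5, so each element is divided by
    // sqrt(7.5 + eps), roughly 2.7386.
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    let scale = [1.0_f32; 4];
    println!("{:?}", rms_norm_row(&x, &scale, 1e-5));
}

Note that the new field is typed candle_nn::LayerNorm, so at this point in candle's history rms_norm evidently returned a LayerNorm configured to skip mean-centering; assuming that configuration matches the formula above, the refactor only delegates the same computation (including the internal dtype handling) to the library.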