Diffstat (limited to 'candle-examples/examples/llama/model.rs')
-rw-r--r--  candle-examples/examples/llama/model.rs  23
1 file changed, 4 insertions(+), 19 deletions(-)
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 751b5902..e0bb70e7 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -152,35 +152,20 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 }
 
 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }
 
 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps, span })
+        let inner = candle_nn::rms_norm(size, eps, vb)?;
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let in_dtype = x.dtype();
-        // This is a no-op if x's dtype is already f32.
-        let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        let x = x.to_dtype(in_dtype)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }
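
Note: the deleted `forward` body computed the usual RMS-norm, y = x / sqrt(mean(x^2) + eps), scaled by the learned `weight` tensor; the `candle_nn::rms_norm` call now takes care of this. Below is a minimal standalone sketch of that math over a plain f32 slice, for reference only; the helper name `rms_norm_ref` and its slice-based signature are hypothetical and not part of candle.

// Reference sketch of the RMS-norm math the removed tensor code performed:
//     y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]
// (hypothetical helper for illustration, not part of candle)
fn rms_norm_ref(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    // Mean of squares over the hidden dimension
    // (the sum_keepdim(2) / hidden_size step in the removed code).
    let hidden_size = x.len() as f32;
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / hidden_size;
    // Normalize, then apply the learned per-channel scale ("weight").
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter()
        .zip(weight)
        .map(|(v, w)| v * inv_rms * w)
        .collect()
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    let w = [1.0_f32; 4];
    println!("{:?}", rms_norm_ref(&x, &w, 1e-5));
}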