diff options
Diffstat (limited to 'candle-examples/examples/quantized')
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 22 |
1 files changed, 4 insertions, 18 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index f42d6f0f..94efb03f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -14,8 +14,7 @@ const MAX_SEQ_LEN: usize = 4096;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
 struct RmsNorm {
-    scale: Tensor,
-    eps: f64,
+    inner: candle_nn::LayerNorm,
     span: tracing::Span,
 }
 
@@ -23,26 +22,13 @@ impl RmsNorm {
     fn new(scale: QTensor) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = scale.dequantize(&Device::Cpu)?;
-        Ok(Self {
-            scale,
-            eps: 1e-5,
-            span,
-        })
+        let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
+        Ok(Self { inner, span })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let (b_sz, seq_len, hidden_size) = x.dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
-        let size = self.scale.dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        Ok(x)
+        self.inner.forward(x)
     }
 }
 