summary refs log tree commit diff
path: root/candle-examples/examples/quantized/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
-rw-r--r-- candle-examples/examples/quantized/main.rs 22
1 files changed, 4 insertions, 18 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index f42d6f0f..94efb03f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -14,8 +14,7 @@ const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
struct RmsNorm {
- scale: Tensor,
- eps: f64,
+ inner: candle_nn::LayerNorm,
span: tracing::Span,
}
@@ -23,26 +22,13 @@ impl RmsNorm {
fn new(scale: QTensor) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?;
- Ok(Self {
- scale,
- eps: 1e-5,
- span,
- })
+ let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
+ Ok(Self { inner, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
- let (b_sz, seq_len, hidden_size) = x.dims3()?;
- let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
- let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
- let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
- let size = self.scale.dims1()?;
- let scale = self
- .scale
- .to_dtype(DType::F32)?
- .broadcast_as((b_sz, seq_len, size))?;
- let x = (scale * x_normed)?;
- Ok(x)
+ self.inner.forward(x)
}
}