//! Utilities for quantized network layers.
//!
//! This module contains various implementations of standard neural network layers, modules and
//! utilities including embedding, linear layers, and various normalization techniques.
//! Most implementations provide quantized weights support.

use crate::models::with_tracing::QMatMul;
use crate::quantized_var_builder::VarBuilder;
use candle::quantized::QTensor;
use candle::{Module, Result, Tensor};

/// An embedding layer whose weights are stored quantized and dequantized once at load time.
#[derive(Debug, Clone)]
pub struct Embedding {
    inner: candle_nn::Embedding,
    span: tracing::Span,
}

impl Embedding {
    pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
        // Dequantize the embedding table up front; lookups then run on a plain tensor.
        let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
        let inner = candle_nn::Embedding::new(embeddings, d2);
        let span = tracing::span!(tracing::Level::TRACE, "embedding");
        Ok(Self { inner, span })
    }

    pub fn embeddings(&self) -> &Tensor {
        self.inner.embeddings()
    }
}

impl Module for Embedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

/// A linear layer that keeps its weights quantized and multiplies through `QMatMul`;
/// only the optional bias is stored dequantized.
#[derive(Debug, Clone)]
pub struct Linear {
    weight: QMatMul,
    bias: Option<Tensor>,
}

impl Linear {
    pub fn from_arc(weight: std::sync::Arc<QTensor>, bias: Option<Tensor>) -> Result<Self> {
        let weight = QMatMul::from_weights(weight)?;
        Ok(Self { weight, bias })
    }

    pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
        Self { weight, bias }
    }
}

impl Module for Linear {
    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        let x = x.apply(&self.weight)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

/// Creates a linear layer with an optional bias, depending on the `bias` flag.
pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let bias = if bias {
        Some(vb.get(out_dim, "bias")?.dequantize(vb.device())?)
    } else {
        None
    };
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear { weight, bias })
}

/// Creates a linear layer with a bias; fails if no `bias` tensor is present.
pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear {
        weight,
        bias: Some(bias),
    })
}

/// Creates a layer norm whose weight and bias are dequantized at load time.
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
}

pub fn layer_norm_no_bias(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
    Ok(candle_nn::LayerNorm::new_no_bias(weight, eps))
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear { weight, bias: None })
}

/// RMS normalization; the scale is dequantized at load time and `eps` is applied in f32.
#[derive(Debug, Clone)]
pub struct RmsNorm {
    weight: Tensor,
    eps: f64,
    span: tracing::Span,
}

impl RmsNorm {
    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
        Ok(Self { weight, eps, span })
    }

    pub fn from_qtensor(weight: QTensor, eps: f64) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let weight = weight.dequantize(&weight.device())?;
        Ok(Self { weight, eps, span })
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        candle_nn::ops::rms_norm(x, &self.weight, self.eps as f32)
    }
}
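
// A minimal sketch of how `Linear` can be built directly from an in-memory
// `QTensor`, without going through a `VarBuilder`. The shapes and the Q8_0
// dtype below are illustrative assumptions, not requirements of this module;
// the only constraint exercised here is that the inner dimension must be a
// multiple of the quantization block size (32 for Q8_0).
#[cfg(test)]
mod tests {
    use super::Linear;
    use candle::quantized::{GgmlDType, QTensor};
    use candle::{Device, Module, Result, Tensor};

    #[test]
    fn quantized_linear_forward() -> Result<()> {
        let device = Device::Cpu;
        // Quantize a small (out_dim, in_dim) weight matrix; the matmul then
        // runs through the quantized kernel rather than a plain f32 matmul.
        let weight = Tensor::randn(0f32, 1f32, (8, 32), &device)?;
        let qweight = QTensor::quantize(&weight, GgmlDType::Q8_0)?;
        let linear = Linear::from_arc(std::sync::Arc::new(qweight), None)?;
        // A batch of two 32-dim inputs maps to two 8-dim outputs.
        let xs = Tensor::randn(0f32, 1f32, (2, 32), &device)?;
        let ys = linear.forward(&xs)?;
        assert_eq!(ys.dims(), &[2, 8]);
        Ok(())
    }
}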