summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/llama/model.rs23
-rw-r--r--candle-examples/examples/llama2-c/model.rs42
-rw-r--r--candle-examples/examples/llama_multiprocess/model.rs45
-rw-r--r--candle-examples/examples/quantized/main.rs22
-rw-r--r--candle-nn/src/layer_norm.rs104
-rw-r--r--candle-nn/src/lib.rs2
-rw-r--r--candle-wasm-examples/llama2-c/src/model.rs44
7 files changed, 124 insertions, 158 deletions
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 751b5902..e0bb70e7 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -152,35 +152,20 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
}
struct RmsNorm {
- scale: Tensor,
- eps: f64,
+ inner: candle_nn::LayerNorm,
span: tracing::Span,
}
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
- let scale = vb.get(size, "weight")?;
- Ok(Self { scale, eps, span })
+ let inner = candle_nn::rms_norm(size, eps, vb)?;
+ Ok(Self { inner, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
- let in_dtype = x.dtype();
- // This is a no-op if x's dtype is already f32.
- let x = x.to_dtype(DType::F32)?;
- let (b_sz, seq_len, hidden_size) = x.dims3()?;
- let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
- let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
- let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
- let size = self.scale.dims1()?;
- let scale = self
- .scale
- .to_dtype(DType::F32)?
- .broadcast_as((b_sz, seq_len, size))?;
- let x = (scale * x_normed)?;
- let x = x.to_dtype(in_dtype)?;
- Ok(x)
+ self.inner.forward(x)
}
}
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 77900d27..75269665 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -1,6 +1,6 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, Embedding, Linear, VarBuilder};
+use candle_nn::{embedding, rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -94,32 +94,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
-struct RmsNorm {
- scale: Tensor,
- eps: f64,
-}
-
-impl RmsNorm {
- fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
- Ok(Self { scale, eps })
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let (b_sz, seq_len, hidden_size) = x.dims3()?;
- let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
- let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
- let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
- let size = self.scale.dims1()?;
- let scale = self
- .scale
- .to_dtype(DType::F32)?
- .broadcast_as((b_sz, seq_len, size))?;
- let x = (scale * x_normed)?;
- Ok(x)
- }
-}
-
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@@ -262,14 +236,14 @@ impl Mlp {
}
struct Block {
- rms_1: RmsNorm,
+ rms_1: LayerNorm,
attn: CausalSelfAttention,
- rms_2: RmsNorm,
+ rms_2: LayerNorm,
mlp: Mlp,
}
impl Block {
- fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+ fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@@ -290,9 +264,9 @@ impl Block {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
- let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+ let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
- RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+ rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
@@ -305,7 +279,7 @@ impl Block {
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
- ln_f: RmsNorm,
+ ln_f: LayerNorm,
lm_head: Linear,
pub config: Config,
}
@@ -325,7 +299,7 @@ impl Llama {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
- let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+ let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
.collect();
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 348248f6..ab4e382c 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,6 +1,6 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use std::rc::Rc;
@@ -182,39 +182,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
Ok(Embedding::new(embeddings, cfg.hidden_size))
}
-struct RmsNorm {
- scale: Tensor,
-}
-
-impl RmsNorm {
- fn load(size: usize, vb: VarBuilder) -> Result<Self> {
- let scale = vb.get(size, "weight")?;
- Ok(Self::new(scale))
- }
-
- fn new(scale: Tensor) -> Self {
- Self { scale }
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let in_dtype = x.dtype();
- // This is a no-op if x's dtype is already f32.
- let x = x.to_dtype(DType::F32)?;
- let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
- let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
- let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
- let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
- let size = self.scale.shape().dims1()?;
- let scale = self
- .scale
- .to_dtype(DType::F32)?
- .broadcast_as((b_sz, seq_len, size))?;
- let x = (scale * x_normed)?;
- let x = x.to_dtype(in_dtype)?;
- Ok(x)
- }
-}
-
struct CausalSelfAttention {
qkv_proj: TensorParallelColumnLinear,
o_proj: TensorParallelRowLinear,
@@ -369,14 +336,14 @@ impl Mlp {
}
struct Block {
- rms_1: RmsNorm,
+ rms_1: LayerNorm,
attn: CausalSelfAttention,
- rms_2: RmsNorm,
+ rms_2: LayerNorm,
mlp: Mlp,
}
impl Block {
- fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+ fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@@ -397,9 +364,9 @@ impl Block {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
- let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
+ let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
- RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
+ rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index f42d6f0f..94efb03f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -14,8 +14,7 @@ const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
struct RmsNorm {
- scale: Tensor,
- eps: f64,
+ inner: candle_nn::LayerNorm,
span: tracing::Span,
}
@@ -23,26 +22,13 @@ impl RmsNorm {
fn new(scale: QTensor) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&Device::Cpu)?;
- Ok(Self {
- scale,
- eps: 1e-5,
- span,
- })
+ let inner = candle_nn::LayerNorm::rms_norm(scale, 1e-5);
+ Ok(Self { inner, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
- let (b_sz, seq_len, hidden_size) = x.dims3()?;
- let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
- let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
- let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
- let size = self.scale.dims1()?;
- let scale = self
- .scale
- .to_dtype(DType::F32)?
- .broadcast_as((b_sz, seq_len, size))?;
- let x = (scale * x_normed)?;
- Ok(x)
+ self.inner.forward(x)
}
}
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 668f9a4b..f9892a2c 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -30,17 +30,70 @@
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Result, Tensor};
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct LayerNormConfig {
+ pub eps: f64,
+ /// Whether to remove the mean or not, the default is true and when set to false, this turns
+ /// this layer into RmsNorm.
+ pub remove_mean: bool,
+ pub affine: bool,
+}
+
+impl Default for LayerNormConfig {
+ fn default() -> Self {
+ Self {
+ eps: 1e-5,
+ remove_mean: true,
+ affine: true,
+ }
+ }
+}
+
+impl From<f64> for LayerNormConfig {
+ fn from(eps: f64) -> Self {
+ Self {
+ eps,
+ remove_mean: true,
+ affine: true,
+ }
+ }
+}
+
// This layer norm version handles both weight and bias so removes the mean.
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
- bias: Tensor,
+ bias: Option<Tensor>,
+ remove_mean: bool,
eps: f64,
}
impl LayerNorm {
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
- Self { weight, bias, eps }
+ Self {
+ weight,
+ bias: Some(bias),
+ remove_mean: true,
+ eps,
+ }
+ }
+
+ pub fn new_no_bias(weight: Tensor, eps: f64) -> Self {
+ Self {
+ weight,
+ bias: None,
+ remove_mean: true,
+ eps,
+ }
+ }
+
+ pub fn rms_norm(weight: Tensor, eps: f64) -> Self {
+ Self {
+ weight,
+ bias: None,
+ remove_mean: false,
+ eps,
+ }
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -51,20 +104,47 @@ impl LayerNorm {
};
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
let x = x.to_dtype(internal_dtype)?;
- let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
- let x = x.broadcast_sub(&mean_x)?;
+ let x = if self.remove_mean {
+ let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+ x.broadcast_sub(&mean_x)?
+ } else {
+ x
+ };
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
- let x = x_normed
- .to_dtype(x_dtype)?
- .broadcast_mul(&self.weight)?
- .broadcast_add(&self.bias)?;
- Ok(x)
+ let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => x.broadcast_add(bias),
+ }
}
}
-pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+pub fn layer_norm<C: Into<LayerNormConfig>>(
+ size: usize,
+ config: C,
+ vb: crate::VarBuilder,
+) -> Result<LayerNorm> {
+ let config = config.into();
let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
- let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
- Ok(LayerNorm::new(weight, bias, eps))
+ let bias = if config.affine {
+ Some(vb.get_or_init(size, "bias", crate::Init::Const(0.))?)
+ } else {
+ None
+ };
+ Ok(LayerNorm {
+ weight,
+ bias,
+ remove_mean: config.remove_mean,
+ eps: config.eps,
+ })
+}
+
+pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+ let config = LayerNormConfig {
+ eps,
+ remove_mean: false,
+ affine: false,
+ };
+ layer_norm(size, config, vb)
}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index ae955f56..05464ceb 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -17,7 +17,7 @@ pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
pub use embedding::{embedding, Embedding};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
-pub use layer_norm::{layer_norm, LayerNorm};
+pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig};
pub use linear::{linear, linear_no_bias, Linear};
pub use optim::{AdamW, ParamsAdamW, SGD};
pub use var_builder::{VarBuilder, VarMap};
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 3231cabf..d2b787ae 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -71,32 +71,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
Ok(Embedding::new(embeddings, cfg.dim))
}
-struct RmsNorm {
- scale: Tensor,
- eps: f64,
-}
-
-impl RmsNorm {
- fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
- let scale = vb.get(size, "weight")?;
- Ok(Self { scale, eps })
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let (b_sz, seq_len, hidden_size) = x.dims3()?;
- let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
- let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
- let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
- let size = self.scale.dims1()?;
- let scale = self
- .scale
- .to_dtype(DType::F32)?
- .broadcast_as((b_sz, seq_len, size))?;
- let x = (scale * x_normed)?;
- Ok(x)
- }
-}
-
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@@ -239,14 +213,14 @@ impl Mlp {
}
struct Block {
- rms_1: RmsNorm,
+ rms_1: LayerNorm,
attn: CausalSelfAttention,
- rms_2: RmsNorm,
+ rms_2: LayerNorm,
mlp: Mlp,
}
impl Block {
- fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+ fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@@ -267,9 +241,9 @@ impl Block {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
- let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+ let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
- RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+ rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
@@ -282,12 +256,12 @@ impl Block {
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
- ln_f: RmsNorm,
+ ln_f: LayerNorm,
lm_head: Linear,
}
impl Llama {
- fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+ fn new(wte: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear) -> Self {
Self {
wte,
blocks,
@@ -311,7 +285,7 @@ impl Llama {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
- let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+ let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
.collect();