Diffstat (limited to 'candle-examples/examples/llama_multiprocess/model.rs')
-rw-r--r-- | candle-examples/examples/llama_multiprocess/model.rs | 45 |
1 file changed, 6 insertions(+), 39 deletions(-)
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 348248f6..ab4e382c 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::rc::Rc;
@@ -182,39 +182,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, cfg.hidden_size))
 }
 
-struct RmsNorm {
-    scale: Tensor,
-}
-
-impl RmsNorm {
-    fn load(size: usize, vb: VarBuilder) -> Result<Self> {
-        let scale = vb.get(size, "weight")?;
-        Ok(Self::new(scale))
-    }
-
-    fn new(scale: Tensor) -> Self {
-        Self { scale }
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let in_dtype = x.dtype();
-        // This is a no-op if x's dtype is already f32.
-        let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
-        let size = self.scale.shape().dims1()?;
-        let scale = self
-            .scale
-            .to_dtype(DType::F32)?
-            .broadcast_as((b_sz, seq_len, size))?;
-        let x = (scale * x_normed)?;
-        let x = x.to_dtype(in_dtype)?;
-        Ok(x)
-    }
-}
-
 struct CausalSelfAttention {
     qkv_proj: TensorParallelColumnLinear,
     o_proj: TensorParallelRowLinear,
@@ -369,14 +336,14 @@ impl Mlp {
 }
 
 struct Block {
-    rms_1: RmsNorm,
+    rms_1: LayerNorm,
     attn: CausalSelfAttention,
-    rms_2: RmsNorm,
+    rms_2: LayerNorm,
     mlp: Mlp,
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+    fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
         Self {
             rms_1,
             attn,
@@ -397,9 +364,9 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
-        let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
+        let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
-            RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
+            rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?;
         Ok(Self::new(
             input_layernorm,
             attn,
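
For context, the removed RmsNorm computed scale * x / sqrt(mean(x^2) + 1e-5) by hand; the patch replaces it with candle_nn's rms_norm constructor using the same 1e-5 epsilon, so the numerics are meant to stay the same. The following is a minimal, self-contained sketch of that helper used outside the model, assuming a candle/candle_nn version that exposes rms_norm, VarMap, VarBuilder::from_varmap, and the Module trait; the hidden size of 8, the toy input shape, and the "norm" prefix are arbitrary values chosen for illustration, not taken from the patch.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{rms_norm, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Freshly initialized weights for the sketch; in the patched Block::load the
    // weights come from the checkpoint via paths such as "input_layernorm".
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // Same constructor call as in Block::load, with a toy hidden size of 8 and
    // the 1e-5 epsilon that the hand-written RmsNorm hard-coded.
    let norm = rms_norm(8, 1e-5, vb.pp("norm"))?;

    // (batch, seq_len, hidden_size) input; normalization is over the last dimension.
    let x = Tensor::randn(0f32, 1f32, (1, 4, 8), &dev)?;
    let y = norm.forward(&x)?;
    assert_eq!(y.dims(), &[1, 4, 8]);
    Ok(())
}

Beyond dropping roughly 30 lines of duplicated code from the tensor-parallel example, reusing the library layer keeps details such as the f32 upcast and epsilon handling in one place.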