summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama_multiprocess/model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama_multiprocess/model.rs')
-rw-r--r--candle-examples/examples/llama_multiprocess/model.rs45
1 files changed, 6 insertions, 39 deletions
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 348248f6..ab4e382c 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,6 +1,6 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{rms_norm, Embedding, LayerNorm, Linear, VarBuilder};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use std::rc::Rc;
@@ -182,39 +182,6 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
Ok(Embedding::new(embeddings, cfg.hidden_size))
}
-struct RmsNorm {
- scale: Tensor,
-}
-
-impl RmsNorm {
- fn load(size: usize, vb: VarBuilder) -> Result<Self> {
- let scale = vb.get(size, "weight")?;
- Ok(Self::new(scale))
- }
-
- fn new(scale: Tensor) -> Self {
- Self { scale }
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let in_dtype = x.dtype();
- // This is a no-op if x's dtype is already f32.
- let x = x.to_dtype(DType::F32)?;
- let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
- let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
- let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
- let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
- let size = self.scale.shape().dims1()?;
- let scale = self
- .scale
- .to_dtype(DType::F32)?
- .broadcast_as((b_sz, seq_len, size))?;
- let x = (scale * x_normed)?;
- let x = x.to_dtype(in_dtype)?;
- Ok(x)
- }
-}
-
struct CausalSelfAttention {
qkv_proj: TensorParallelColumnLinear,
o_proj: TensorParallelRowLinear,
@@ -369,14 +336,14 @@ impl Mlp {
}
struct Block {
- rms_1: RmsNorm,
+ rms_1: LayerNorm,
attn: CausalSelfAttention,
- rms_2: RmsNorm,
+ rms_2: LayerNorm,
mlp: Mlp,
}
impl Block {
- fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+ fn new(rms_1: LayerNorm, attn: CausalSelfAttention, rms_2: LayerNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
@@ -397,9 +364,9 @@ impl Block {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
- let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
+ let input_layernorm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
- RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
+ rms_norm(cfg.hidden_size, 1e-5, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,