Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/llama_multiprocess/model.rs | 23 |
1 file changed, 11 insertions, 12 deletions
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index e902734f..becaa879 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -4,7 +4,6 @@ use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
-use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
@@ -24,11 +23,11 @@ impl TensorParallelColumnLinear {
 
 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 struct AllReduce {
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 impl CustomOp1 for AllReduce {
@@ -61,12 +60,12 @@ impl CustomOp1 for AllReduce {
     }
 }
 
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
     x.custom_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
-    fn new(linear: Linear, comm: Rc<Comm>) -> Self {
+    fn new(linear: Linear, comm: Arc<Comm>) -> Self {
         Self { linear, comm }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -76,14 +75,14 @@ impl TensorParallelRowLinear {
     }
 }
 impl TensorParallelColumnLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 0, rank, size)?;
         Ok(Self::new(Linear::new(weight, None)))
     }
 
-    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
+    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
@@ -96,7 +95,7 @@ impl TensorParallelColumnLinear {
 }
 
 impl TensorParallelRowLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],
@@ -388,7 +387,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }
 
-    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -422,7 +421,7 @@ impl Block {
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -466,7 +465,7 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
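Context for the change (not stated in the diff itself): Rc uses non-atomic reference counting and is neither Send nor Sync, so a struct holding an Rc<Comm> cannot be moved to or shared across threads; Arc is the thread-safe equivalent, which is presumably what the multi-process/multi-GPU example needs when the NCCL communicator is used from more than one thread. A minimal standalone sketch of the difference, using only the Rust standard library (the vector and worker thread below are illustrative, not taken from the example):

    use std::sync::Arc;
    use std::thread;

    fn main() {
        // An Arc can be cloned and moved into another thread; the same code
        // written with std::rc::Rc would fail to compile because Rc is not Send.
        let shared = Arc::new(vec![1, 2, 3]);
        let worker = thread::spawn({
            let shared = Arc::clone(&shared);
            move || shared.iter().sum::<i32>()
        });
        assert_eq!(worker.join().unwrap(), 6);
        // The original handle is still usable on the main thread.
        println!("{:?}", shared);
    }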