author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-26 11:16:04 +0200
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-27 09:58:47 +0200
commit | 7c7e6ba201d0270f5ac689c20f16f59e00ed4d01 (patch)
tree | 134efa4d30f7dcaac74b1ec858692c50671a5953 /candle-examples/examples/llama_multiprocess/model.rs
parent | 1553b58fe59a29fe808b9b4d43a6502046ce26dd (diff)
download | candle-7c7e6ba201d0270f5ac689c20f16f59e00ed4d01.tar.gz, candle-7c7e6ba201d0270f5ac689c20f16f59e00ed4d01.tar.bz2, candle-7c7e6ba201d0270f5ac689c20f16f59e00ed4d01.zip
Removing inner dependency on safetensors.
Diffstat (limited to 'candle-examples/examples/llama_multiprocess/model.rs')
-rw-r--r-- | candle-examples/examples/llama_multiprocess/model.rs | 23
1 files changed, 11 insertions, 12 deletions
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index e902734f..becaa879 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -4,7 +4,6 @@ use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
-use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
@@ -24,11 +23,11 @@ impl TensorParallelColumnLinear {
 
 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 struct AllReduce {
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 impl CustomOp1 for AllReduce {
@@ -61,12 +60,12 @@ impl CustomOp1 for AllReduce {
     }
 }
 
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
     x.custom_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
-    fn new(linear: Linear, comm: Rc<Comm>) -> Self {
+    fn new(linear: Linear, comm: Arc<Comm>) -> Self {
         Self { linear, comm }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -76,14 +75,14 @@ impl TensorParallelRowLinear {
 }
 
 impl TensorParallelColumnLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 0, rank, size)?;
         Ok(Self::new(Linear::new(weight, None)))
     }
 
-    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
+    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
@@ -96,7 +95,7 @@ impl TensorParallelColumnLinear {
 }
 
 impl TensorParallelRowLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
        let qkv_proj = TensorParallelColumnLinear::load_multi(
            vb.clone(),
            &["q_proj", "k_proj", "v_proj"],
@@ -388,7 +387,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }
 
-    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -422,7 +421,7 @@ impl Block {
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -466,7 +465,7 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
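The functional change in this diff is swapping `Rc<Comm>` for `Arc<Comm>` everywhere the NCCL communicator handle is stored or passed. A minimal sketch of why that matters, using a hypothetical `Comm` stand-in rather than cudarc's actual type: `Rc` uses a non-atomic reference count and is neither `Send` nor `Sync`, so an `Rc` handle cannot cross a thread boundary, whereas `Arc` clones can be moved into worker threads.

```rust
use std::rc::Rc;
use std::sync::Arc;
use std::thread;

// Hypothetical stand-in for the communicator handle; not cudarc's real `Comm`.
struct Comm {
    rank: usize,
}

fn main() {
    // Rc<Comm> is neither Send nor Sync: it cannot be moved into another thread.
    let _rc_comm = Rc::new(Comm { rank: 0 });
    // The following line would fail to compile:
    // thread::spawn(move || println!("rank {}", _rc_comm.rank));

    // Arc uses an atomic reference count, so clones of the handle can be shared
    // across threads (as long as the inner type is Send + Sync).
    let arc_comm = Arc::new(Comm { rank: 0 });
    let workers: Vec<_> = (0..2)
        .map(|i| {
            let comm = Arc::clone(&arc_comm);
            thread::spawn(move || println!("worker {i} sees rank {}", comm.rank))
        })
        .collect();
    for w in workers {
        w.join().unwrap();
    }
}
```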
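The surrounding context also shows the tensor-parallel pattern this example relies on: column-parallel layers load a weight shard along dimension 0, row-parallel layers shard along dimension 1 and sum their partial outputs via `all_reduce_sum`. Below is a toy, single-process sketch of that reduction step with plain arrays (not candle tensors or NCCL), only to illustrate why the row-parallel path needs a sum across ranks.

```rust
// Toy 2-rank row-parallel matrix-vector product: the weight is split along its
// input dimension, each "rank" multiplies its shard against the matching slice
// of the activation, and the partial results are summed (the all-reduce step).
fn matvec(w: &[[f32; 2]; 2], x: &[f32; 2]) -> [f32; 2] {
    [
        w[0][0] * x[0] + w[0][1] * x[1],
        w[1][0] * x[0] + w[1][1] * x[1],
    ]
}

fn main() {
    // Full weight (2 outputs x 2 inputs) and activation.
    let w = [[1.0_f32, 2.0], [3.0, 4.0]];
    let x = [10.0_f32, 20.0];
    let full = matvec(&w, &x);

    // Row-parallel sharding: rank 0 owns input column 0, rank 1 owns column 1.
    let partial0 = [w[0][0] * x[0], w[1][0] * x[0]];
    let partial1 = [w[0][1] * x[1], w[1][1] * x[1]];
    // "All-reduce sum" over the per-rank partial results.
    let reduced = [partial0[0] + partial1[0], partial0[1] + partial1[1]];

    assert_eq!(full, reduced);
    println!("full = {full:?}, reduced = {reduced:?}");
}
```

In the real example the sum runs across GPUs through NCCL's `ReduceOp::Sum`; the assertion here only checks that summing per-shard partial products reproduces the unsharded result.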