author    | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-26 10:22:40 +0000
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-27 09:58:47 +0200
commit    | 25a2086e8f4cc23fada32a44607d3b8550916ebe (patch)
tree      | fa5d6c4fdd61f582cd731875b88cdce19d206fac
parent    | 7c7e6ba201d0270f5ac689c20f16f59e00ed4d01 (diff)
download  | candle-25a2086e8f4cc23fada32a44607d3b8550916ebe.tar.gz, candle-25a2086e8f4cc23fada32a44607d3b8550916ebe.tar.bz2, candle-25a2086e8f4cc23fada32a44607d3b8550916ebe.zip
Putting back Send + Sync
-rw-r--r-- | candle-core/src/op.rs | 2
-rw-r--r-- | candle-examples/examples/llama_multiprocess/model.rs | 30
-rw-r--r-- | candle-pyo3/src/lib.rs | 2

3 files changed, 21 insertions, 13 deletions
```diff
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 83b382cd..525383b2 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -103,7 +103,7 @@ pub enum Op {
 }

 /// Unary ops that can be defined in user-land.
-pub trait CustomOp1 {
+pub trait CustomOp1: Send + Sync {
     // Box<dyn> does not support const yet, so use a function to get the name.
     fn name(&self) -> &'static str;

diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index becaa879..bcf6ed2b 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -4,6 +4,7 @@ use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
+use std::rc::Rc;
 use std::sync::{Arc, Mutex};

 use super::MAX_SEQ_LEN;
@@ -23,13 +24,20 @@ impl TensorParallelColumnLinear {
 }

 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Arc<Comm>,
+    comm: Rc<Comm>,
 }

 struct AllReduce {
-    comm: Arc<Comm>,
+    comm: Rc<Comm>,
 }

+/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
+/// But for this example purposes, this will work
+unsafe impl Sync for AllReduce {}
+/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
+/// But for this example purposes, this will work
+unsafe impl Send for AllReduce {}
+
 impl CustomOp1 for AllReduce {
     fn name(&self) -> &'static str {
         "allreduce"
@@ -60,12 +68,12 @@ impl CustomOp1 for AllReduce {
     }
 }

-fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
     x.custom_op1(AllReduce { comm: comm.clone() })
 }

 impl TensorParallelRowLinear {
-    fn new(linear: Linear, comm: Arc<Comm>) -> Self {
+    fn new(linear: Linear, comm: Rc<Comm>) -> Self {
         Self { linear, comm }
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -75,14 +83,14 @@ impl TensorParallelRowLinear {
 }

 impl TensorParallelColumnLinear {
-    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 0, rank, size)?;
         Ok(Self::new(Linear::new(weight, None)))
     }

-    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
+    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
@@ -95,7 +103,7 @@ impl TensorParallelColumnLinear {
 }

 impl TensorParallelRowLinear {
-    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -338,7 +346,7 @@ impl CausalSelfAttention {
         }
     }

-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],
@@ -387,7 +395,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }

-    fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -421,7 +429,7 @@ impl Block {
         Ok(x)
     }

-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -465,7 +473,7 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }

-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 6e206688..136f8a4f 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -11,7 +11,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
 }

 #[derive(Clone)]
-#[pyclass(name = "Tensor", unsendable)]
+#[pyclass(name = "Tensor")]
 struct PyTensor(Tensor);

 impl std::ops::Deref for PyTensor {
```
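
Why the `Send + Sync` supertrait matters: once every `CustomOp1` implementor is required to be thread-safe, trait objects built from the trait are `Send + Sync` as well, and values holding them can cross threads without extra bounds. That is presumably also what allows the PyO3 wrapper to drop `unsendable`, since a `#[pyclass]` type must be `Send` unless it is marked `unsendable`. Below is a minimal, self-contained sketch of that mechanism; `CustomOp` and `Square` are hypothetical stand-ins, not candle's real trait or ops.

```rust
use std::sync::Arc;
use std::thread;

// Hypothetical trait mirroring the `CustomOp1: Send + Sync` bound from the
// diff; the real candle trait has more required methods than shown here.
trait CustomOp: Send + Sync {
    fn name(&self) -> &'static str;
}

// Hypothetical op used only for this illustration.
struct Square;

impl CustomOp for Square {
    fn name(&self) -> &'static str {
        "square"
    }
}

fn main() {
    // Because the supertrait forces every implementor to be Send + Sync,
    // `dyn CustomOp` is Send + Sync too, so an Arc'd op can be cloned and
    // moved into another thread without any extra bounds.
    let op: Arc<dyn CustomOp> = Arc::new(Square);
    let op_for_thread = Arc::clone(&op);
    let handle = thread::spawn(move || op_for_thread.name());
    println!("{} / {}", op.name(), handle.join().unwrap());
}
```

Without the `: Send + Sync` supertrait, the `thread::spawn` call above would not compile; every use site would have to spell out `Arc<dyn CustomOp + Send + Sync>` instead.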
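
The `Rc<Comm>` plus `unsafe impl Send`/`unsafe impl Sync` part of the example is the standard escape hatch for a handle the compiler cannot prove thread-safe: the programmer asserts the property, and the doc comments in the diff link the NCCL thread-safety caveats that make this a deliberate shortcut for the example. Here is a self-contained sketch of the same pattern, with `Comm` as a dummy stand-in rather than `cudarc::nccl::safe::Comm`:

```rust
use std::rc::Rc;
use std::thread;

// Dummy stand-in for a handle that is not Send/Sync on its own; in the
// diff this role is played by `cudarc::nccl::safe::Comm` behind an `Rc`.
struct Comm {
    rank: usize,
}

// Wrapper mirroring the example's `AllReduce`.
struct AllReduce {
    comm: Rc<Comm>,
}

// As in the commit: a promise made by the programmer, not something the
// compiler verifies. The NCCL docs linked in the diff explain why it is
// not strictly safe; it is accepted here for the sake of the example.
unsafe impl Send for AllReduce {}
unsafe impl Sync for AllReduce {}

impl AllReduce {
    // In the real example this would launch the NCCL all-reduce; here we
    // only read something through the wrapped handle.
    fn run(&self) -> usize {
        self.comm.rank
    }
}

fn main() {
    let op = AllReduce {
        comm: Rc::new(Comm { rank: 0 }),
    };
    // Without the unsafe impls above this would not compile, because the
    // spawned closure must be Send while it holds an `Rc<Comm>`.
    let handle = thread::spawn(move || op.run());
    println!("rank seen from the worker thread: {}", handle.join().unwrap());
}
```

One likely reason `Arc<Comm>` could be replaced by `Rc<Comm>` is that `Arc<T>` is only `Send`/`Sync` when `T` itself is, so the atomically counted pointer bought no thread-safety here anyway; the cheaper `Rc` suffices once the `unsafe impl`s take on that responsibility at the `AllReduce` level.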