-rw-r--r--  Cargo.toml                                            |  2
-rw-r--r--  candle-core/src/op.rs                                 |  2
-rw-r--r--  candle-examples/examples/llama_multiprocess/main.rs   | 28
-rw-r--r--  candle-examples/examples/llama_multiprocess/model.rs  | 66
4 files changed, 62 insertions(+), 36 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,7 +3,7 @@ members = [
     "candle-core",
     "candle-examples",
     "candle-nn",
-    "candle-pyo3",
+    # "candle-pyo3",
     "candle-transformers",
     "candle-wasm-examples/llama2-c",
     "candle-wasm-examples/whisper",
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 525383b2..83b382cd 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -103,7 +103,7 @@ pub enum Op {
 }
 
 /// Unary ops that can be defined in user-land.
-pub trait CustomOp1: Send + Sync {
+pub trait CustomOp1 {
     // Box<dyn> does not support const yet, so use a function to get the name.
     fn name(&self) -> &'static str;
 
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index 22c121dd..f9e87432 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -247,20 +247,24 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
         new_tokens.push(next_token);
-        println!("> {:?}", start_gen.elapsed());
+        if rank == 0 {
+            println!("> {:?}", start_gen.elapsed());
+            println!(
+                "{} token: {} '{}'",
+                index + 1,
+                next_token,
+                tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+            );
+        }
+    }
+    let dt = start_gen.elapsed();
+    if rank == 0 {
         println!(
-            "{} token: {} '{}'",
-            index + 1,
-            next_token,
-            tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+            "{} tokens generated ({} token/s)\n----\n{}\n----",
+            args.sample_len,
+            args.sample_len as f64 / dt.as_secs_f64(),
+            tokenizer.decode(new_tokens, true).map_err(E::msg)?
         );
     }
-    let dt = start_gen.elapsed();
-    println!(
-        "{} tokens generated ({} token/s)\n----\n{}\n----",
-        args.sample_len,
-        args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(new_tokens, true).map_err(E::msg)?
-    );
     Ok(())
 }
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 4e46b526..e902734f 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,6 +1,6 @@
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle::backend::BackendStorage;
+use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
 use candle_nn::{Embedding, Linear, VarBuilder};
-use cudarc::driver::safe::CudaSlice;
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
@@ -27,14 +27,42 @@ struct TensorParallelRowLinear {
     comm: Rc<Comm>,
 }
 
+struct AllReduce {
+    comm: Rc<Comm>,
+}
+
+impl CustomOp1 for AllReduce {
+    fn name(&self) -> &'static str {
+        "allreduce"
+    }
+
+    fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
+        todo!("implement allreduce for cpu is not necessary for single node");
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s: &candle::CudaStorage,
+        l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::cuda_backend::WrapErr;
+        let elem_count = l.shape().elem_count();
+        let dev = s.device().clone();
+        let s = s.as_cuda_slice::<f16>()?;
+        // let s = match l.contiguous_offsets() {
+        //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+        //     Some((o1, o2)) => s.slice(o1..o2),
+        // };
+        let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+        self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
+        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+        Ok((dst, l.shape().clone()))
+    }
+}
+
 fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
-    Ok(x.clone())
-    // let n = x.shape().elem_count();
-    // let cuda_slice: CudaSlice<f16> = x.try_into()?;
-    // let dev = cuda_slice.device();
-    // let mut slice_receive = dev.alloc_zeros(n).unwrap();
-    // comm.all_reduce(cuda_slice, &mut slice_receive, &ReduceOp::Sum).unwrap();
-    // Tensor::from_raw_storage(slice_receive, x.shape())
+    x.custom_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
@@ -187,11 +215,11 @@ impl RmsNorm {
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
-        let size = self.scale.shape().r1()?;
+        let size = self.scale.shape().dims1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
@@ -213,7 +241,7 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
+        let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
         let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
@@ -227,7 +255,7 @@ impl CausalSelfAttention {
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let x_dtype = x.dtype();
-        let (b_sz, seq_len, _) = x.shape().r3()?;
+        let (b_sz, seq_len, _) = x.shape().dims3()?;
         let qkv = self.qkv_proj.forward(x)?;
 
         let n_embd = self.n_head * self.head_dim;
@@ -302,7 +330,7 @@ impl CausalSelfAttention {
         if n_rep == 1 {
             Ok(x)
         } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
             let x = x
                 .unsqueeze(2)?
                 .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
@@ -312,10 +340,6 @@ impl CausalSelfAttention {
     }
 
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
-        let size_in = cfg.hidden_size;
-        let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
-        let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
-
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],
@@ -364,9 +388,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }
 
-    fn load(vb: VarBuilder, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
-        let h_size = cfg.hidden_size;
-        let i_size = cfg.intermediate_size;
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -433,7 +455,7 @@ impl Llama {
     }
 
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, seq_len) = x.shape().r2()?;
+        let (_b_sz, seq_len) = x.shape().dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx)?;
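
Note: the hunks above add the AllReduce custom op and rewrite all_reduce_sum on top of it, but the call site falls outside the diff context. As a rough sketch of how the helper is meant to be consumed (assuming TensorParallelRowLinear also holds a sharded candle_nn::Linear in a field named `linear`, which is not visible in the context shown), the row-parallel forward would look roughly like this:

impl TensorParallelRowLinear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Each rank multiplies by its own row shard of the weight, so the
        // local matmul yields only a partial sum; the NCCL all-reduce wrapped
        // by the AllReduce custom op combines the partials into the full output.
        let x = self.linear.forward(x)?;
        all_reduce_sum(&x, &self.comm)
    }
}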