author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-20 22:19:46 +0200
committer | GitHub <noreply@github.com> | 2024-04-20 22:19:46 +0200
commit | 587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d (patch)
tree | 122db3c84eccdef1dd0451b0c939104ab03a4113
parent | dd78422701e9c6f3ca74218e8aedcf032c6c7215 (diff)
download | candle-587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d.tar.gz, candle-587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d.tar.bz2, candle-587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d.zip
Small cleanups to the llama multi-process example. (#2098)
-rw-r--r-- | candle-core/src/error.rs | 6
-rw-r--r-- | candle-examples/examples/llama_multiprocess/main.rs | 6
-rw-r--r-- | candle-examples/examples/llama_multiprocess/model.rs | 104
-rw-r--r-- | candle-transformers/src/models/llama.rs | 8
4 files changed, 54 insertions, 70 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 60ddea11..e7112e2e 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -219,10 +219,14 @@ impl Error {
         Self::Wrapped(Box::new(err)).bt()
     }
 
-    pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+    pub fn msg(err: impl std::error::Error) -> Self {
         Self::Msg(err.to_string()).bt()
     }
 
+    pub fn debug(err: impl std::fmt::Debug) -> Self {
+        Self::Msg(format!("{err:?}")).bt()
+    }
+
     pub fn bt(self) -> Self {
         let backtrace = std::backtrace::Backtrace::capture();
         match backtrace.status() {
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index 3b03b873..f540e084 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -76,7 +76,7 @@ struct Args {
     #[arg(long)]
     dtype: Option<String>,
 
-    #[arg(long)]
+    #[arg(long, default_value = "v3-8b")]
     which: Which,
 
     #[arg(long, default_value = "nccl_id.txt")]
@@ -219,6 +219,9 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
         new_tokens.push(next_token);
+        if Some(next_token) == config.eos_token_id {
+            break;
+        }
         if rank == 0 {
             if let Some(t) = tokenizer.next_token(next_token)? {
                 print!("{t}");
@@ -226,6 +229,7 @@
             }
         }
     }
+    println!();
     if rank == 0 {
         let dt = start_gen.elapsed();
         println!(
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 414b1242..1fbf566c 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,15 +1,14 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
+use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
 use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
-use half::{bf16, f16};
-use serde::Deserialize;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
 
-use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
+pub type Config = candle_transformers::models::llama::LlamaConfig;
 
 struct TensorParallelColumnLinear {
     linear: Linear,
@@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {
 
 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    all_reduce: AllReduce,
 }
 
 struct AllReduce {
@@ -36,8 +35,6 @@
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
-/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
-/// But for this example purposes, this will work
 unsafe impl Send for AllReduce {}
 
 impl CustomOp1 for AllReduce {
@@ -46,7 +43,7 @@
     }
 
     fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
-        todo!("implement allreduce for cpu is not necessary for single node");
+        candle::bail!("AllReduce is never used on cpu")
     }
 
     #[cfg(feature = "cuda")]
@@ -56,47 +53,49 @@ impl CustomOp1 for AllReduce {
         l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::WrapErr;
+        use cudarc::driver::DeviceSlice;
+        use half::{bf16, f16};
+
         let elem_count = l.shape().elem_count();
         let dev = s.device().clone();
-        match s.dtype() {
+        let dst = match s.dtype() {
             DType::BF16 => {
                 let s = s.as_cuda_slice::<bf16>()?;
-                // let s = match l.contiguous_offsets() {
-                //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-                //     Some((o1, o2)) => s.slice(o1..o2),
-                // };
+                let s = match l.contiguous_offsets() {
+                    Some((0, l)) if l == s.len() => s,
+                    Some(_) | None => candle::bail!("input has to be contiguous"),
+                };
                 let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
-                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-                Ok((dst, l.shape().clone()))
+                self.comm
+                    .all_reduce(s, &mut dst, &ReduceOp::Sum)
+                    .map_err(candle::Error::debug)?;
+                candle::CudaStorage::wrap_cuda_slice(dst, dev)
             }
             DType::F16 => {
                 let s = s.as_cuda_slice::<f16>()?;
-                // let s = match l.contiguous_offsets() {
-                //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-                //     Some((o1, o2)) => s.slice(o1..o2),
-                // };
+                let s = match l.contiguous_offsets() {
+                    Some((0, l)) if l == s.len() => s,
+                    Some(_) | None => candle::bail!("input has to be contiguous"),
+                };
                 let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-                Ok((dst, l.shape().clone()))
+                self.comm
+                    .all_reduce(s, &mut dst, &ReduceOp::Sum)
+                    .map_err(candle::Error::debug)?;
+                candle::CudaStorage::wrap_cuda_slice(dst, dev)
            }
            dtype => candle::bail!("unsupported dtype {dtype:?}"),
-        }
+        };
+        Ok((dst, l.shape().clone()))
     }
 }
 
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
-    x.apply_op1(AllReduce { comm: comm.clone() })
-}
-
 impl TensorParallelRowLinear {
     fn new(linear: Linear, comm: Rc<Comm>) -> Self {
-        Self { linear, comm }
+        let all_reduce = AllReduce { comm };
+        Self { linear, all_reduce }
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = self.linear.forward(x)?;
-        all_reduce_sum(&x, &self.comm)
+        self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
     }
 }
@@ -137,23 +136,6 @@ impl TensorParallelRowLinear {
     }
 }
 
-#[derive(Deserialize)]
-pub struct Config {
-    pub hidden_size: usize,
-    pub intermediate_size: usize,
-    pub vocab_size: usize,
-    pub num_hidden_layers: usize,
-    pub num_attention_heads: usize,
-    pub num_key_value_heads: usize,
-    pub rms_norm_eps: f64,
-    #[serde(default = "default_rope")]
-    pub rope_theta: f32,
-}
-
-fn default_rope() -> f32 {
-    10_000.0
-}
-
 #[derive(Clone)]
 pub struct Cache {
     #[allow(clippy::type_complexity)]
@@ -281,9 +263,7 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
         let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
-            .transpose(1, 2)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
+            .reshape((b_sz, seq_len, hidden_size))?;
         let y = self.o_proj.forward(&y)?;
         Ok(y)
     }
@@ -304,7 +284,7 @@ impl CausalSelfAttention {
             qkv_proj,
             o_proj,
             num_attention_heads: cfg.num_attention_heads / comm.world_size(),
-            num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
+            num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
             head_dim: cfg.hidden_size / cfg.num_attention_heads,
             cache: cache.clone(),
         })
@@ -318,18 +298,6 @@ struct Mlp {
 }
 
 impl Mlp {
-    fn new(
-        c_fc1: TensorParallelColumnLinear,
-        c_fc2: TensorParallelColumnLinear,
-        c_proj: TensorParallelRowLinear,
-    ) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
@@ -339,7 +307,11 @@ impl Mlp {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
+        Ok(Self {
+            c_fc1,
+            c_fc2,
+            c_proj,
+        })
     }
 }
@@ -430,10 +402,8 @@ impl Llama {
                     cfg,
                     comm.clone(),
                 )
-                .unwrap()
             })
-            .collect();
-
+            .collect::<Result<Vec<_>>>()?;
         Ok(Self::new(wte, blocks, norm, lm_head))
     }
 }
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 945c0e17..57d2f593 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -20,6 +20,12 @@ pub struct LlamaConfig {
     pub eos_token_id: Option<u32>,
 }
 
+impl LlamaConfig {
+    pub fn num_key_value_heads(&self) -> usize {
+        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
+    }
+}
+
 fn default_rope() -> f32 {
     10_000.0
 }
@@ -32,7 +38,7 @@ impl LlamaConfig {
             vocab_size: self.vocab_size,
             num_hidden_layers: self.num_hidden_layers,
             num_attention_heads: self.num_attention_heads,
-            num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
+            num_key_value_heads: self.num_key_value_heads(),
             rms_norm_eps: self.rms_norm_eps,
             rope_theta: self.rope_theta,
             use_flash_attn,
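For context on the two API patterns this cleanup leans on — `Tensor::apply_op1_no_bwd` for running a `CustomOp1` without registering a backward pass, and the new `candle::Error::debug` helper for propagating `Debug`-only errors (such as nccl results) with `?` instead of `.unwrap()` — here is a minimal sketch that is not part of the commit. The `Negate` op and `FakeNcclError` type are made up for illustration, and the crate is referred to by its published name `candle_core` (inside this repository it is imported as `candle`).

```rust
use candle_core::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

// Hypothetical stand-in for an error type that implements `Debug` but not
// `std::error::Error`, which is the situation `Error::debug` is added for.
#[allow(dead_code)]
#[derive(Debug)]
struct FakeNcclError;

// A toy CPU-only op mirroring the structure of the example's `AllReduce`:
// same contiguity check, same `map_err(Error::debug)?` error propagation.
struct Negate;

impl CustomOp1 for Negate {
    fn name(&self) -> &'static str {
        "negate"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        let s = s.as_slice::<f32>()?;
        // Same pattern as the diff: only accept a contiguous, full-length input.
        let s = match l.contiguous_offsets() {
            Some((0, len)) if len == s.len() => s,
            Some(_) | None => candle_core::bail!("input has to be contiguous"),
        };
        // Pretend an underlying library call reported a `Debug`-only error;
        // `Error::debug` converts it via its `{:?}` formatting.
        let status: std::result::Result<(), FakeNcclError> = Ok(());
        status.map_err(candle_core::Error::debug)?;
        let out: Vec<f32> = s.iter().map(|v| -v).collect();
        Ok((CpuStorage::F32(out), l.shape().clone()))
    }
}

fn main() -> Result<()> {
    let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // `apply_op1_no_bwd` runs the op without registering a backward pass,
    // which is how the example now applies its `AllReduce` op.
    let y = x.apply_op1_no_bwd(&Negate)?;
    println!("{y}");
    Ok(())
}
```

The contiguity check mirrors the one the commit adds to `AllReduce::cuda_fwd`: the op only accepts storage that is contiguous and starts at offset 0, and bails otherwise instead of silently reducing a strided view.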