author     Laurent Mazare <laurent.mazare@gmail.com>   2024-04-20 22:19:46 +0200
committer  GitHub <noreply@github.com>                 2024-04-20 22:19:46 +0200
commit     587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d (patch)
tree       122db3c84eccdef1dd0451b0c939104ab03a4113
parent     dd78422701e9c6f3ca74218e8aedcf032c6c7215 (diff)
Small cleanups to the llama multi-process example. (#2098)
-rw-r--r--  candle-core/src/error.rs                             |   6
-rw-r--r--  candle-examples/examples/llama_multiprocess/main.rs  |   6
-rw-r--r--  candle-examples/examples/llama_multiprocess/model.rs | 104
-rw-r--r--  candle-transformers/src/models/llama.rs              |   8
4 files changed, 54 insertions(+), 70 deletions(-)
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 60ddea11..e7112e2e 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -219,10 +219,14 @@ impl Error {
Self::Wrapped(Box::new(err)).bt()
}
- pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+ pub fn msg(err: impl std::error::Error) -> Self {
Self::Msg(err.to_string()).bt()
}
+ pub fn debug(err: impl std::fmt::Debug) -> Self {
+ Self::Msg(format!("{err:?}")).bt()
+ }
+
pub fn bt(self) -> Self {
let backtrace = std::backtrace::Backtrace::capture();
match backtrace.status() {
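
As an aside, a minimal sketch of what the new `Error::debug` constructor is for: some error types (such as the NCCL errors returned by cudarc) implement `Debug` but not `std::error::Error`, so they cannot go through `Error::msg`. The `Driverish` type below is a made-up stand-in for such an error, and `candle_core` is assumed as the dependency name; this mirrors the `.map_err(candle::Error::debug)?` calls further down in model.rs.

    use candle_core::{Error, Result};

    // A Debug-only error type, standing in for something like cudarc's NCCL
    // result type, which does not implement std::error::Error.
    #[derive(Debug)]
    struct Driverish {
        code: i32,
    }

    fn fallible() -> std::result::Result<(), Driverish> {
        Err(Driverish { code: 2 })
    }

    fn run() -> Result<()> {
        // `map_err(Error::debug)` turns the Debug-only error into a candle
        // error by formatting it with `{:?}`.
        fallible().map_err(Error::debug)?;
        Ok(())
    }

    fn main() {
        if let Err(e) = run() {
            eprintln!("{e}");
        }
    }
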
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index 3b03b873..f540e084 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -76,7 +76,7 @@ struct Args {
#[arg(long)]
dtype: Option<String>,
- #[arg(long)]
+ #[arg(long, default_value = "v3-8b")]
which: Which,
#[arg(long, default_value = "nccl_id.txt")]
@@ -219,6 +219,9 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
+ if Some(next_token) == config.eos_token_id {
+ break;
+ }
if rank == 0 {
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
@@ -226,6 +229,7 @@ fn main() -> Result<()> {
}
}
}
+ println!();
if rank == 0 {
let dt = start_gen.elapsed();
println!(
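
A stripped-down sketch of the new early-exit behaviour in the generation loop above, with `sample_next` standing in for `logits_processor.sample(&logits)?` and no model or tokenizer involved:

    // Stop sampling as soon as the token matches the (optional) eos token
    // from the config, as the new `break` in main.rs does.
    fn generate(eos_token_id: Option<u32>, mut sample_next: impl FnMut() -> u32) -> Vec<u32> {
        let mut tokens = Vec::new();
        for _ in 0..16 {
            let next_token = sample_next();
            tokens.push(next_token);
            if Some(next_token) == eos_token_id {
                break;
            }
        }
        tokens
    }

    fn main() {
        let mut stream = [5u32, 7, 2, 9].into_iter().cycle();
        let out = generate(Some(2), || stream.next().unwrap());
        assert_eq!(out, vec![5, 7, 2]);
        println!("{out:?}");
    }
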
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 414b1242..1fbf566c 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,15 +1,14 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
+use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
use candle_nn::{Embedding, Linear, Module, RmsNorm};
use cudarc::nccl::safe::{Comm, ReduceOp};
-use half::{bf16, f16};
-use serde::Deserialize;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
-use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
+pub type Config = candle_transformers::models::llama::LlamaConfig;
struct TensorParallelColumnLinear {
linear: Linear,
@@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {
struct TensorParallelRowLinear {
linear: Linear,
- comm: Rc<Comm>,
+ all_reduce: AllReduce,
}
struct AllReduce {
@@ -36,8 +35,6 @@ struct AllReduce {
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for this example purposes, this will work
unsafe impl Sync for AllReduce {}
-/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
-/// But for this example purposes, this will work
unsafe impl Send for AllReduce {}
impl CustomOp1 for AllReduce {
@@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
- todo!("implement allreduce for cpu is not necessary for single node");
+ candle::bail!("AllReduce is never used on cpu")
}
#[cfg(feature = "cuda")]
@@ -56,47 +53,49 @@ impl CustomOp1 for AllReduce {
l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::WrapErr;
+ use cudarc::driver::DeviceSlice;
+ use half::{bf16, f16};
+
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
- match s.dtype() {
+ let dst = match s.dtype() {
DType::BF16 => {
let s = s.as_cuda_slice::<bf16>()?;
- // let s = match l.contiguous_offsets() {
- // None => Err(Error::Wrapped("input has to be contiguous".into()))?,
- // Some((o1, o2)) => s.slice(o1..o2),
- // };
+ let s = match l.contiguous_offsets() {
+ Some((0, l)) if l == s.len() => s,
+ Some(_) | None => candle::bail!("input has to be contiguous"),
+ };
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
- self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
- let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
- Ok((dst, l.shape().clone()))
+ self.comm
+ .all_reduce(s, &mut dst, &ReduceOp::Sum)
+ .map_err(candle::Error::debug)?;
+ candle::CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F16 => {
let s = s.as_cuda_slice::<f16>()?;
- // let s = match l.contiguous_offsets() {
- // None => Err(Error::Wrapped("input has to be contiguous".into()))?,
- // Some((o1, o2)) => s.slice(o1..o2),
- // };
+ let s = match l.contiguous_offsets() {
+ Some((0, l)) if l == s.len() => s,
+ Some(_) | None => candle::bail!("input has to be contiguous"),
+ };
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
- self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
- let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
- Ok((dst, l.shape().clone()))
+ self.comm
+ .all_reduce(s, &mut dst, &ReduceOp::Sum)
+ .map_err(candle::Error::debug)?;
+ candle::CudaStorage::wrap_cuda_slice(dst, dev)
}
dtype => candle::bail!("unsupported dtype {dtype:?}"),
- }
+ };
+ Ok((dst, l.shape().clone()))
}
}
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
- x.apply_op1(AllReduce { comm: comm.clone() })
-}
-
impl TensorParallelRowLinear {
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
- Self { linear, comm }
+ let all_reduce = AllReduce { comm };
+ Self { linear, all_reduce }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = self.linear.forward(x)?;
- all_reduce_sum(&x, &self.comm)
+ self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
}
}
@@ -137,23 +136,6 @@ impl TensorParallelRowLinear {
}
}
-#[derive(Deserialize)]
-pub struct Config {
- pub hidden_size: usize,
- pub intermediate_size: usize,
- pub vocab_size: usize,
- pub num_hidden_layers: usize,
- pub num_attention_heads: usize,
- pub num_key_value_heads: usize,
- pub rms_norm_eps: f64,
- #[serde(default = "default_rope")]
- pub rope_theta: f32,
-}
-
-fn default_rope() -> f32 {
- 10_000.0
-}
-
#[derive(Clone)]
pub struct Cache {
#[allow(clippy::type_complexity)]
@@ -281,9 +263,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
- .transpose(1, 2)?;
- // Convert to contiguous as matmul doesn't support strided vs for now.
- let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
+ .reshape((b_sz, seq_len, hidden_size))?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
@@ -304,7 +284,7 @@ impl CausalSelfAttention {
qkv_proj,
o_proj,
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
- num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
+ num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
head_dim: cfg.hidden_size / cfg.num_attention_heads,
cache: cache.clone(),
})
@@ -318,18 +298,6 @@ struct Mlp {
}
impl Mlp {
- fn new(
- c_fc1: TensorParallelColumnLinear,
- c_fc2: TensorParallelColumnLinear,
- c_proj: TensorParallelRowLinear,
- ) -> Self {
- Self {
- c_fc1,
- c_fc2,
- c_proj,
- }
- }
-
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
@@ -339,7 +307,11 @@ impl Mlp {
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
- Ok(Self::new(c_fc1, c_fc2, c_proj))
+ Ok(Self {
+ c_fc1,
+ c_fc2,
+ c_proj,
+ })
}
}
@@ -430,10 +402,8 @@ impl Llama {
cfg,
comm.clone(),
)
- .unwrap()
})
- .collect();
-
+ .collect::<Result<Vec<_>>>()?;
Ok(Self::new(wte, blocks, norm, lm_head))
}
}
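
For readers unfamiliar with candle custom ops, here is a small CPU-only sketch of the `CustomOp1` / `apply_op1_no_bwd` pattern that `AllReduce` uses above, including the same contiguity guard now applied in `cuda_fwd`. The op itself (a plain negation) is a placeholder and involves no NCCL or CUDA; `candle_core` is assumed as the dependency name.

    use candle_core::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

    struct Negate;

    impl CustomOp1 for Negate {
        fn name(&self) -> &'static str {
            "negate"
        }

        fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
            let s = s.as_slice::<f32>()?;
            // Same contiguity guard shape as the AllReduce cuda_fwd above.
            let s = match l.contiguous_offsets() {
                Some((o1, o2)) => &s[o1..o2],
                None => candle_core::bail!("input has to be contiguous"),
            };
            let out: Vec<f32> = s.iter().map(|v| -v).collect();
            Ok((CpuStorage::F32(out), l.shape().clone()))
        }
    }

    fn main() -> Result<()> {
        let t = Tensor::new(&[1f32, -2., 3.], &Device::Cpu)?;
        // Apply the custom op without backward support, as TensorParallelRowLinear
        // now does with its AllReduce op.
        let out = t.apply_op1_no_bwd(&Negate)?;
        println!("{out}");
        Ok(())
    }
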
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 945c0e17..57d2f593 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -20,6 +20,12 @@ pub struct LlamaConfig {
pub eos_token_id: Option<u32>,
}
+impl LlamaConfig {
+ pub fn num_key_value_heads(&self) -> usize {
+ self.num_key_value_heads.unwrap_or(self.num_attention_heads)
+ }
+}
+
fn default_rope() -> f32 {
10_000.0
}
@@ -32,7 +38,7 @@ impl LlamaConfig {
vocab_size: self.vocab_size,
num_hidden_layers: self.num_hidden_layers,
num_attention_heads: self.num_attention_heads,
- num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
+ num_key_value_heads: self.num_key_value_heads(),
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
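
Finally, a tiny illustration of the new `num_key_value_heads()` accessor: grouped-query-attention checkpoints set the field explicitly, while older configs omit it and fall back to `num_attention_heads`. `MiniConfig` below is a pared-down stand-in for `LlamaConfig`, not the real type.

    struct MiniConfig {
        num_attention_heads: usize,
        num_key_value_heads: Option<usize>,
    }

    impl MiniConfig {
        // Same fallback rule as LlamaConfig::num_key_value_heads above.
        fn num_key_value_heads(&self) -> usize {
            self.num_key_value_heads.unwrap_or(self.num_attention_heads)
        }
    }

    fn main() {
        let gqa = MiniConfig { num_attention_heads: 32, num_key_value_heads: Some(8) };
        let mha = MiniConfig { num_attention_heads: 32, num_key_value_heads: None };
        assert_eq!(gqa.num_key_value_heads(), 8);
        assert_eq!(mha.num_key_value_heads(), 32);
    }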