Diffstat (limited to 'candle-examples/examples/llama_multiprocess/model.rs')
-rw-r--r--  candle-examples/examples/llama_multiprocess/model.rs  66
1 file changed, 44 insertions(+), 22 deletions(-)
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 4e46b526..e902734f 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -1,6 +1,6 @@
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle::backend::BackendStorage;
+use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
-use cudarc::driver::safe::CudaSlice;
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use std::collections::HashMap;
@@ -27,14 +27,42 @@ struct TensorParallelRowLinear {
comm: Rc<Comm>,
}
+struct AllReduce {
+ comm: Rc<Comm>,
+}
+
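+// `AllReduce` exposes the NCCL communicator as a candle `CustomOp1`, so a
+// cross-rank sum can be dispatched through `Tensor::custom_op1` below.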
+impl CustomOp1 for AllReduce {
+ fn name(&self) -> &'static str {
+ "allreduce"
+ }
+
+ fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
+ todo!("implement allreduce for cpu is not necessary for single node");
+ }
+
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(
+ &self,
+ s: &candle::CudaStorage,
+ l: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ use candle::cuda_backend::WrapErr;
+ let elem_count = l.shape().elem_count();
+ let dev = s.device().clone();
+ let s = s.as_cuda_slice::<f16>()?;
+ // let s = match l.contiguous_offsets() {
+ // None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+ // Some((o1, o2)) => s.slice(o1..o2),
+ // };
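+ // The input is assumed to be contiguous here (see the disabled check above);
+ // dst is allocated uninitialized and fully overwritten by the all-reduce.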
+ let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+ self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
+ let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+ Ok((dst, l.shape().clone()))
+ }
+}
+
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
- Ok(x.clone())
- // let n = x.shape().elem_count();
- // let cuda_slice: CudaSlice<f16> = x.try_into()?;
- // let dev = cuda_slice.device();
- // let mut slice_receive = dev.alloc_zeros(n).unwrap();
- // comm.all_reduce(cuda_slice, &mut slice_receive, &ReduceOp::Sum).unwrap();
- // Tensor::from_raw_storage(slice_receive, x.shape())
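+ // Sum-reduce x across all ranks using the custom CUDA op defined above.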
+ x.custom_op1(AllReduce { comm: comm.clone() })
}
impl TensorParallelRowLinear {
@@ -187,11 +215,11 @@ impl RmsNorm {
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
- let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+ let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
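+ // rms-norm in f32: x / sqrt(mean(x^2, last dim) + 1e-6), then scaled.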
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
- let size = self.scale.shape().r1()?;
+ let size = self.scale.shape().dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
@@ -213,7 +241,7 @@ struct CausalSelfAttention {
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
+ let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
@@ -227,7 +255,7 @@ impl CausalSelfAttention {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let x_dtype = x.dtype();
- let (b_sz, seq_len, _) = x.shape().r3()?;
+ let (b_sz, seq_len, _) = x.shape().dims3()?;
let qkv = self.qkv_proj.forward(x)?;
let n_embd = self.n_head * self.head_dim;
@@ -302,7 +330,7 @@ impl CausalSelfAttention {
if n_rep == 1 {
Ok(x)
} else {
- let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
+ let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
@@ -312,10 +340,6 @@ impl CausalSelfAttention {
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
- let size_in = cfg.hidden_size;
- let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
- let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
-
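+ // q_proj, k_proj and v_proj are loaded together as one fused column-parallel projection.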
let qkv_proj = TensorParallelColumnLinear::load_multi(
vb.clone(),
&["q_proj", "k_proj", "v_proj"],
@@ -364,9 +388,7 @@ impl Mlp {
self.c_proj.forward(&x)
}
- fn load(vb: VarBuilder, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
- let h_size = cfg.hidden_size;
- let i_size = cfg.intermediate_size;
+ fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
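+ // gate_proj/up_proj are column-sharded; down_proj is row-parallel, so its output is all-reduced across ranks.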
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -433,7 +455,7 @@ impl Llama {
}
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (_b_sz, seq_len) = x.shape().r2()?;
+ let (_b_sz, seq_len) = x.shape().dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;