-rw-r--r--  candle-examples/examples/gemma/main.rs                  |  5
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_launch_template.h   |  4
-rw-r--r--  candle-flash-attn/src/lib.rs                            |  4
-rw-r--r--  candle-transformers/src/models/gemma.rs                 | 62
4 files changed, 55 insertions, 20 deletions
diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs
index a5f7d591..31c55618 100644
--- a/candle-examples/examples/gemma/main.rs
+++ b/candle-examples/examples/gemma/main.rs
@@ -193,6 +193,9 @@ struct Args {
     /// The model to use.
     #[arg(long, default_value = "2b")]
     which: Which,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }
 
 fn main() -> Result<()> {
@@ -270,7 +273,7 @@ fn main() -> Result<()> {
         DType::F32
     };
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(args.use_flash_attn, &config, vb)?;
 
     println!("loaded the model in {:?}", start.elapsed());
 
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 66ab6206..002dd8ec 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                 // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                 // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                 // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+                if (smem_size >= 48 * 1024) {
+                    cudaFuncSetAttribute(
+                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+                }
                 // int ctas_per_sm;
                 // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                 //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 21a06b5e..f171a986 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -139,7 +139,9 @@ impl FlashAttn {
 
         let elem_count = out_shape.elem_count();
         let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
-        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+        let softmax_lse = dev
+            .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
+            .w()?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 3bde88b4..1cfef59e 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -73,13 +73,6 @@ struct RotaryEmbedding {
     cos: Tensor,
 }
 
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
-    let last_dim = xs.dim(D::Minus1)?;
-    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
-    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
-    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.head_dim;
@@ -94,7 +87,6 @@ impl RotaryEmbedding {
             .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             sin: freqs.sin()?,
             cos: freqs.cos()?,
@@ -110,10 +102,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
-        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
@@ -163,10 +153,16 @@ struct Attention {
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     kv_cache: Option<(Tensor, Tensor)>,
+    use_flash_attn: bool,
 }
 
 impl Attention {
-    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    fn new(
+        rotary_emb: Arc<RotaryEmbedding>,
+        use_flash_attn: bool,
+        cfg: &Config,
+        vb: VarBuilder,
+    ) -> Result<Self> {
         let hidden_sz = cfg.hidden_size;
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
@@ -188,6 +184,7 @@ impl Attention {
             head_dim,
             rotary_emb,
             kv_cache: None,
+            use_flash_attn,
         })
     }
 
@@ -231,7 +228,14 @@ impl Attention {
         let value_states =
             crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
-        let attn_output = {
+        let attn_output = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = query_states.transpose(1, 2)?;
+            let k = key_states.transpose(1, 2)?;
+            let v = value_states.transpose(1, 2)?;
+            let scale = 1f32 / (self.head_dim as f32).sqrt();
+            flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
+        } else {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
             let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
 
@@ -253,6 +257,22 @@ impl Attention {
     }
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug, Clone)]
 struct DecoderLayer {
     self_attn: Attention,
@@ -262,8 +282,13 @@ struct DecoderLayer {
 }
 
 impl DecoderLayer {
-    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+    fn new(
+        rotary_emb: Arc<RotaryEmbedding>,
+        use_flash_attn: bool,
+        cfg: &Config,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
         let input_layernorm =
             RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -312,7 +337,7 @@ pub struct Model {
 }
 
 impl Model {
-    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb_m = vb.pp("model");
         let embed_tokens =
             candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
@@ -320,7 +345,8 @@ impl Model {
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
         let vb_l = vb_m.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
-            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+            let layer =
+                DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
         let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
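The gemma.rs hunks drop both `rotate_half` and the `Tensor::cat(&[&freqs, &freqs], D::Minus1)?` duplication: as the removed lines suggest, the fused `candle_nn::rotary_emb::rope` helper consumes half-width cos/sin tables and handles the rotation pairing internally. A minimal, illustrative sketch of the shapes involved, assuming the in-workspace crate name `candle` for candle-core and random tensors purely for demonstration (not part of this change):

use candle::{Device, Result, Tensor};

fn rope_shapes_demo() -> Result<Tensor> {
    let dev = Device::Cpu;
    let (b_sz, n_heads, seq_len, head_dim) = (1, 2, 4, 8);
    // q is (b_sz, n_heads, seq_len, head_dim) and must be contiguous.
    let q = Tensor::randn(0f32, 1f32, (b_sz, n_heads, seq_len, head_dim), &dev)?;
    // cos/sin are (seq_len, head_dim / 2): one entry per rotation pair, which is
    // why the full-width `cat` of the freqs table is no longer needed above.
    let cos = Tensor::randn(0f32, 1f32, (seq_len, head_dim / 2), &dev)?;
    let sin = Tensor::randn(0f32, 1f32, (seq_len, head_dim / 2), &dev)?;
    candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)
}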
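On the example side, clap derives the long flag `--use-flash-attn` from the new `use_flash_attn` field. Since the `#[cfg(not(feature = "flash-attn"))]` stub panics via `unimplemented!` only when attention is first evaluated, a binary exposing that flag may prefer to reject it up front when built without the feature. A hypothetical guard, not part of the diff; the helper name, the `anyhow` error handling, and the assumption that the example crate exposes a `flash-attn` cargo feature are all illustrative:

// Hypothetical guard: fail with a clear message at startup instead of hitting the
// `unimplemented!` stub at the first attention call.
fn check_flash_attn_flag(requested: bool) -> anyhow::Result<()> {
    if requested && !cfg!(feature = "flash-attn") {
        anyhow::bail!("--use-flash-attn requires building with '--features flash-attn'");
    }
    Ok(())
}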