summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/gemma/main.rs5
-rw-r--r--candle-flash-attn/kernels/flash_fwd_launch_template.h4
-rw-r--r--candle-flash-attn/src/lib.rs4
-rw-r--r--candle-transformers/src/models/gemma.rs62
4 files changed, 55 insertions, 20 deletions
diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs
index a5f7d591..31c55618 100644
--- a/candle-examples/examples/gemma/main.rs
+++ b/candle-examples/examples/gemma/main.rs
@@ -193,6 +193,9 @@ struct Args {
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
+
+ #[arg(long)]
+ use_flash_attn: bool,
}
fn main() -> Result<()> {
@@ -270,7 +273,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
- let model = Model::new(&config, vb)?;
+ let model = Model::new(args.use_flash_attn, &config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 66ab6206..002dd8ec 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+ if (smem_size >= 48 * 1024) {
+ cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+ }
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 21a06b5e..f171a986 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -139,7 +139,9 @@ impl FlashAttn {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
- let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
+ let softmax_lse = dev
+ .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
+ .w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 3bde88b4..1cfef59e 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -73,13 +73,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
- let last_dim = xs.dim(D::Minus1)?;
- let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
- let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
- Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
@@ -94,7 +87,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
- let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@@ -110,10 +102,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
- let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
- let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
@@ -163,10 +153,16 @@ struct Attention {
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
+ use_flash_attn: bool,
}
impl Attention {
- fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ fn new(
+ rotary_emb: Arc<RotaryEmbedding>,
+ use_flash_attn: bool,
+ cfg: &Config,
+ vb: VarBuilder,
+ ) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
@@ -188,6 +184,7 @@ impl Attention {
head_dim,
rotary_emb,
kv_cache: None,
+ use_flash_attn,
})
}
@@ -231,7 +228,14 @@ impl Attention {
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
- let attn_output = {
+ let attn_output = if self.use_flash_attn {
+ // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+ let q = query_states.transpose(1, 2)?;
+ let k = key_states.transpose(1, 2)?;
+ let v = value_states.transpose(1, 2)?;
+ let scale = 1f32 / (self.head_dim as f32).sqrt();
+ flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
+ } else {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
@@ -253,6 +257,22 @@ impl Attention {
}
}
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
@@ -262,8 +282,13 @@ struct DecoderLayer {
}
impl DecoderLayer {
- fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
- let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+ fn new(
+ rotary_emb: Arc<RotaryEmbedding>,
+ use_flash_attn: bool,
+ cfg: &Config,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -312,7 +337,7 @@ pub struct Model {
}
impl Model {
- pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
@@ -320,7 +345,8 @@ impl Model {
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
- let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+ let layer =
+ DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;