author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-09 10:46:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-09 10:46:09 +0100 |
commit | 3cd7e7b51dc7bf49215b136702f5bc3cd4642144 (patch) | |
tree | d024ff2edebb0e7af84d471d6daaa2be90b12c9a /candle-examples | |
parent | 722c50bb0ce18d749edcf86268238e5d9c9ee57e (diff) | |
Fuse the rel-pos additions via a custom-op. (#786)
* Fuse the rel-pos additions via a custom-op.
* Run with rayon.
* Add more tracing.
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/Cargo.toml | 1 |
-rw-r--r-- | candle-examples/examples/segment-anything/model_image_encoder.rs | 91 |
2 files changed, 86 insertions, 6 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 9035eae0..6f8792a3 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -24,6 +24,7 @@ intel-mkl-src = { workspace = true, optional = true }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true }
+rayon = { workspace = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index f997170d..76cd15d0 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -34,6 +34,70 @@ impl Module for PatchEmbed {
     }
 }
 
+// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
+// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
+//   (attn.reshape((b, q_h, q_w, k_h, k_w))?
+//       + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+//   .reshape((b, q_h * q_w, k_h * k_w))
+// Ideally we would perform this operation in place but this is not supported in candle at the
+// moment. We should also investigate using f16 rather than f32.
+struct Add3(usize, usize, usize, usize, usize);
+impl candle::CustomOp3 for Add3 {
+    fn name(&self) -> &'static str {
+        "add3"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &candle::CpuStorage,
+        l1: &candle::Layout,
+        s2: &candle::CpuStorage,
+        l2: &candle::Layout,
+        s3: &candle::CpuStorage,
+        l3: &candle::Layout,
+    ) -> Result<(candle::CpuStorage, candle::Shape)> {
+        use rayon::prelude::*;
+
+        let Add3(b, q_h, q_w, k_h, k_w) = *self;
+        let s1 = s1.as_slice::<f32>()?;
+        let s1 = match l1.contiguous_offsets() {
+            None => candle::bail!("input1 has to be contiguous"),
+            Some((o1, o2)) => &s1[o1..o2],
+        };
+        let s2 = s2.as_slice::<f32>()?;
+        let s2 = match l2.contiguous_offsets() {
+            None => candle::bail!("input2 has to be contiguous"),
+            Some((o1, o2)) => &s2[o1..o2],
+        };
+        let s3 = s3.as_slice::<f32>()?;
+        let s3 = match l3.contiguous_offsets() {
+            None => candle::bail!("input3 has to be contiguous"),
+            Some((o1, o2)) => &s3[o1..o2],
+        };
+        let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
+        dst.par_chunks_exact_mut(k_h * k_w)
+            .enumerate()
+            .for_each(|(b_idx, dst)| {
+                let s1_idx = b_idx * k_h * k_w;
+                let s2_idx = b_idx * k_h;
+                let s3_idx = b_idx * k_w;
+                for h_idx in 0..k_h {
+                    let s1_idx = s1_idx + h_idx * k_w;
+                    let s2_idx = s2_idx + h_idx;
+                    let dst_idx = h_idx * k_w;
+                    for w_idx in 0..k_w {
+                        let s1_idx = s1_idx + w_idx;
+                        let s3_idx = s3_idx + w_idx;
+                        let dst_idx = dst_idx + w_idx;
+                        dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
+                    }
+                }
+            });
+        let dst = candle::WithDType::to_cpu_storage_owned(dst);
+        Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
+    }
+}
+
 #[derive(Debug)]
 struct Attention {
     qkv: crate::Linear,
@@ -42,6 +106,7 @@ struct Attention {
     scale: f64,
     rel_pos_hw: Option<(Tensor, Tensor)>,
     span: tracing::Span,
+    span_matmul: tracing::Span,
     span_rel_pos: tracing::Span,
     span_softmax: tracing::Span,
 }
@@ -56,6 +121,7 @@ impl Attention {
         vb: VarBuilder,
     ) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
         let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
         let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@@ -76,6 +142,7 @@ impl Attention {
             scale,
             rel_pos_hw,
             span,
+            span_matmul,
             span_rel_pos,
             span_softmax,
         })
@@ -101,10 +168,16 @@ impl Attention {
                     .transpose(1, 2)? // -> bwhc
                     .contiguous()?
                     .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
-                    .transpose(1, 2)?;
-                (attn.reshape((b, q_h, q_w, k_h, k_w))?
-                    + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
-                .reshape((b, q_h * q_w, k_h * k_w))
+                    .transpose(1, 2)?
+                    .contiguous()?;
+                if attn.device().is_cpu() {
+                    let op = Add3(b, q_h, q_w, k_h, k_w);
+                    attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
+                } else {
+                    (attn.reshape((b, q_h, q_w, k_h, k_w))?
+                        + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+                        .reshape((b, q_h * q_w, k_h * k_w))
+                }
             }
             None => Ok(attn),
         }
@@ -149,7 +222,10 @@ impl Module for Attention {
         let q = qkv.i(0)?;
         let k = qkv.i(1)?;
         let v = qkv.i(2)?;
-        let attn = (&q * self.scale)?.matmul(&k.t()?)?;
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            (&q * self.scale)?.matmul(&k.t()?)?
+        };
         let attn = {
             let _enter = self.span_rel_pos.enter();
             self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
@@ -158,7 +234,10 @@
             let _enter = self.span_softmax.enter();
             candle_nn::ops::softmax_last_dim(&attn)?
         };
-        let attn = attn.matmul(&v)?;
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            attn.matmul(&v)?
+        };
         let attn = attn
             .reshape((b, self.num_heads, h, w, c / self.num_heads))?
             .permute((0, 2, 3, 1, 4))?
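The point of the fusion is to produce the `(b, q_h * q_w, k_h * k_w)` result in one rayon-parallel pass instead of materializing the broadcasted intermediate and then adding it to `attn`. The sketch below is not part of the patch: it is a minimal sanity check of the fused path against the unfused `reshape` + `broadcast_add` expression it replaces, assuming it sits in the same module as `Add3` above so the op is in scope. The function name `check_add3_against_broadcast` and the toy sizes are made up for illustration, and the check relies on candle helpers (`Tensor::randn`, `flatten_all`, `to_vec1`) that I believe were available at this point in the tree.

```rust
// Hypothetical sanity check, not part of this commit. Assumes it lives next to
// `Add3` in model_image_encoder.rs; the sizes are toy values, not the b = 12
// workload mentioned in the comment above.
fn check_add3_against_broadcast() -> candle::Result<()> {
    use candle::{Device, Tensor};

    let dev = Device::Cpu;
    let (b, q_h, q_w, k_h, k_w) = (2, 3, 4, 3, 4);

    // The three inputs in the layout the fused op expects: f32 and contiguous.
    let attn = Tensor::randn(0f32, 1f32, (b, q_h * q_w, k_h * k_w), &dev)?;
    let rel_h = Tensor::randn(0f32, 1f32, (b, q_h, q_w, k_h), &dev)?;
    let rel_w = Tensor::randn(0f32, 1f32, (b, q_h, q_w, k_w), &dev)?;

    // Fused path: a single pass over the output buffer, as in the patch.
    let fused = attn.apply_op3_no_bwd(&rel_h, &rel_w, &Add3(b, q_h, q_w, k_h, k_w))?;

    // Unfused path: the reshape + broadcast_add expression the custom op replaces.
    let reference = (attn.reshape((b, q_h, q_w, k_h, k_w))?
        + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
        .reshape((b, q_h * q_w, k_h * k_w))?;

    // Both paths compute the same a + b + c per element; compare them directly.
    let fused = fused.flatten_all()?.to_vec1::<f32>()?;
    let reference = reference.flatten_all()?.to_vec1::<f32>()?;
    for (f, r) in fused.iter().zip(reference.iter()) {
        assert!((f - r).abs() < 1e-5);
    }
    Ok(())
}
```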