summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-26 10:23:43 +0200
committerGitHub <noreply@github.com>2024-09-26 10:23:43 +0200
commit10d47183c088ce449da13d74f07171c8106cd6dd (patch)
treeb91b0398fcb314e998b9f7f3b23877f63462b232 /candle-transformers
parentd01207dbf3fb0ad614e7915c8f5706fbc09902fb (diff)
downloadcandle-10d47183c088ce449da13d74f07171c8106cd6dd.tar.gz
candle-10d47183c088ce449da13d74f07171c8106cd6dd.tar.bz2
candle-10d47183c088ce449da13d74f07171c8106cd6dd.zip
Quantized version of flux. (#2500)
* Quantized version of flux. * More generic sampling. * Hook the quantized model. * Use the newly minted gguf file. * Fix for the quantized model. * Default to avoid the faster cuda kernels.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/flux/mod.rs17
-rw-r--r--candle-transformers/src/models/flux/model.rs10
-rw-r--r--candle-transformers/src/models/flux/quantized_model.rs465
-rw-r--r--candle-transformers/src/models/flux/sampling.rs4
4 files changed, 490 insertions, 6 deletions
diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs
index 763fa90d..b0c8a693 100644
--- a/candle-transformers/src/models/flux/mod.rs
+++ b/candle-transformers/src/models/flux/mod.rs
@@ -1,3 +1,20 @@
+use candle::{Result, Tensor};
+
+pub trait WithForward {
+ #[allow(clippy::too_many_arguments)]
+ fn forward(
+ &self,
+ img: &Tensor,
+ img_ids: &Tensor,
+ txt: &Tensor,
+ txt_ids: &Tensor,
+ timesteps: &Tensor,
+ y: &Tensor,
+ guidance: Option<&Tensor>,
+ ) -> Result<Tensor>;
+}
+
pub mod autoencoder;
pub mod model;
+pub mod quantized_model;
pub mod sampling;
diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs
index 4e47873f..17b4eb25 100644
--- a/candle-transformers/src/models/flux/model.rs
+++ b/candle-transformers/src/models/flux/model.rs
@@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}
-fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
+pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
let q = apply_rope(q, pe)?.contiguous()?;
let k = apply_rope(k, pe)?.contiguous()?;
let x = scaled_dot_product_attention(&q, &k, v)?;
x.transpose(1, 2)?.flatten_from(2)
}
-fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
+pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
@@ -144,7 +144,7 @@ pub struct EmbedNd {
}
impl EmbedNd {
- fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
+ pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
Self {
dim,
theta,
@@ -575,9 +575,11 @@ impl Flux {
final_layer,
})
}
+}
+impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
- pub fn forward(
+ fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
diff --git a/candle-transformers/src/models/flux/quantized_model.rs b/candle-transformers/src/models/flux/quantized_model.rs
new file mode 100644
index 00000000..0efeeab5
--- /dev/null
+++ b/candle-transformers/src/models/flux/quantized_model.rs
@@ -0,0 +1,465 @@
+use super::model::{attention, timestep_embedding, Config, EmbedNd};
+use crate::quantized_nn::{linear, linear_b, Linear};
+use crate::quantized_var_builder::VarBuilder;
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{LayerNorm, RmsNorm};
+
+fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
+ let ws = Tensor::ones(dim, DType::F32, vb.device())?;
+ Ok(LayerNorm::new_no_bias(ws, 1e-6))
+}
+
+#[derive(Debug, Clone)]
+pub struct MlpEmbedder {
+ in_layer: Linear,
+ out_layer: Linear,
+}
+
+impl MlpEmbedder {
+ fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
+ let in_layer = linear(in_sz, h_sz, vb.pp("in_layer"))?;
+ let out_layer = linear(h_sz, h_sz, vb.pp("out_layer"))?;
+ Ok(Self {
+ in_layer,
+ out_layer,
+ })
+ }
+}
+
+impl candle::Module for MlpEmbedder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct QkNorm {
+ query_norm: RmsNorm,
+ key_norm: RmsNorm,
+}
+
+impl QkNorm {
+ fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+ let query_norm = vb.get(dim, "query_norm.scale")?.dequantize(vb.device())?;
+ let query_norm = RmsNorm::new(query_norm, 1e-6);
+ let key_norm = vb.get(dim, "key_norm.scale")?.dequantize(vb.device())?;
+ let key_norm = RmsNorm::new(key_norm, 1e-6);
+ Ok(Self {
+ query_norm,
+ key_norm,
+ })
+ }
+}
+
+struct ModulationOut {
+ shift: Tensor,
+ scale: Tensor,
+ gate: Tensor,
+}
+
+impl ModulationOut {
+ fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.broadcast_mul(&(&self.scale + 1.)?)?
+ .broadcast_add(&self.shift)
+ }
+
+ fn gate(&self, xs: &Tensor) -> Result<Tensor> {
+ self.gate.broadcast_mul(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Modulation1 {
+ lin: Linear,
+}
+
+impl Modulation1 {
+ fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+ let lin = linear(dim, 3 * dim, vb.pp("lin"))?;
+ Ok(Self { lin })
+ }
+
+ fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
+ let ys = vec_
+ .silu()?
+ .apply(&self.lin)?
+ .unsqueeze(1)?
+ .chunk(3, D::Minus1)?;
+ if ys.len() != 3 {
+ candle::bail!("unexpected len from chunk {ys:?}")
+ }
+ Ok(ModulationOut {
+ shift: ys[0].clone(),
+ scale: ys[1].clone(),
+ gate: ys[2].clone(),
+ })
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Modulation2 {
+ lin: Linear,
+}
+
+impl Modulation2 {
+ fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+ let lin = linear(dim, 6 * dim, vb.pp("lin"))?;
+ Ok(Self { lin })
+ }
+
+ fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
+ let ys = vec_
+ .silu()?
+ .apply(&self.lin)?
+ .unsqueeze(1)?
+ .chunk(6, D::Minus1)?;
+ if ys.len() != 6 {
+ candle::bail!("unexpected len from chunk {ys:?}")
+ }
+ let mod1 = ModulationOut {
+ shift: ys[0].clone(),
+ scale: ys[1].clone(),
+ gate: ys[2].clone(),
+ };
+ let mod2 = ModulationOut {
+ shift: ys[3].clone(),
+ scale: ys[4].clone(),
+ gate: ys[5].clone(),
+ };
+ Ok((mod1, mod2))
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct SelfAttention {
+ qkv: Linear,
+ norm: QkNorm,
+ proj: Linear,
+ num_heads: usize,
+}
+
+impl SelfAttention {
+ fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
+ let head_dim = dim / num_heads;
+ let qkv = linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
+ let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
+ let proj = linear(dim, dim, vb.pp("proj"))?;
+ Ok(Self {
+ qkv,
+ norm,
+ proj,
+ num_heads,
+ })
+ }
+
+ fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
+ let qkv = xs.apply(&self.qkv)?;
+ let (b, l, _khd) = qkv.dims3()?;
+ let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
+ let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
+ let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
+ let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
+ let q = q.apply(&self.norm.query_norm)?;
+ let k = k.apply(&self.norm.key_norm)?;
+ Ok((q, k, v))
+ }
+
+ #[allow(unused)]
+ fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
+ let (q, k, v) = self.qkv(xs)?;
+ attention(&q, &k, &v, pe)?.apply(&self.proj)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Mlp {
+ lin1: Linear,
+ lin2: Linear,
+}
+
+impl Mlp {
+ fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
+ let lin1 = linear(in_sz, mlp_sz, vb.pp("0"))?;
+ let lin2 = linear(mlp_sz, in_sz, vb.pp("2"))?;
+ Ok(Self { lin1, lin2 })
+ }
+}
+
+impl candle::Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct DoubleStreamBlock {
+ img_mod: Modulation2,
+ img_norm1: LayerNorm,
+ img_attn: SelfAttention,
+ img_norm2: LayerNorm,
+ img_mlp: Mlp,
+ txt_mod: Modulation2,
+ txt_norm1: LayerNorm,
+ txt_attn: SelfAttention,
+ txt_norm2: LayerNorm,
+ txt_mlp: Mlp,
+}
+
+impl DoubleStreamBlock {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let h_sz = cfg.hidden_size;
+ let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
+ let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
+ let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
+ let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
+ let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
+ let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
+ let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
+ let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
+ let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
+ let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
+ let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
+ Ok(Self {
+ img_mod,
+ img_norm1,
+ img_attn,
+ img_norm2,
+ img_mlp,
+ txt_mod,
+ txt_norm1,
+ txt_attn,
+ txt_norm2,
+ txt_mlp,
+ })
+ }
+
+ fn forward(
+ &self,
+ img: &Tensor,
+ txt: &Tensor,
+ vec_: &Tensor,
+ pe: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
+ let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
+ let img_modulated = img.apply(&self.img_norm1)?;
+ let img_modulated = img_mod1.scale_shift(&img_modulated)?;
+ let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
+
+ let txt_modulated = txt.apply(&self.txt_norm1)?;
+ let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
+ let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
+
+ let q = Tensor::cat(&[txt_q, img_q], 2)?;
+ let k = Tensor::cat(&[txt_k, img_k], 2)?;
+ let v = Tensor::cat(&[txt_v, img_v], 2)?;
+
+ let attn = attention(&q, &k, &v, pe)?;
+ let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
+ let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
+
+ let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
+ let img = (&img
+ + img_mod2.gate(
+ &img_mod2
+ .scale_shift(&img.apply(&self.img_norm2)?)?
+ .apply(&self.img_mlp)?,
+ )?)?;
+
+ let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
+ let txt = (&txt
+ + txt_mod2.gate(
+ &txt_mod2
+ .scale_shift(&txt.apply(&self.txt_norm2)?)?
+ .apply(&self.txt_mlp)?,
+ )?)?;
+
+ Ok((img, txt))
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct SingleStreamBlock {
+ linear1: Linear,
+ linear2: Linear,
+ norm: QkNorm,
+ pre_norm: LayerNorm,
+ modulation: Modulation1,
+ h_sz: usize,
+ mlp_sz: usize,
+ num_heads: usize,
+}
+
+impl SingleStreamBlock {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let h_sz = cfg.hidden_size;
+ let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
+ let head_dim = h_sz / cfg.num_heads;
+ let linear1 = linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
+ let linear2 = linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
+ let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
+ let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
+ let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
+ Ok(Self {
+ linear1,
+ linear2,
+ norm,
+ pre_norm,
+ modulation,
+ h_sz,
+ mlp_sz,
+ num_heads: cfg.num_heads,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
+ let mod_ = self.modulation.forward(vec_)?;
+ let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
+ let x_mod = x_mod.apply(&self.linear1)?;
+ let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
+ let (b, l, _khd) = qkv.dims3()?;
+ let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
+ let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
+ let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
+ let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
+ let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
+ let q = q.apply(&self.norm.query_norm)?;
+ let k = k.apply(&self.norm.key_norm)?;
+ let attn = attention(&q, &k, &v, pe)?;
+ let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
+ xs + mod_.gate(&output)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct LastLayer {
+ norm_final: LayerNorm,
+ linear: Linear,
+ ada_ln_modulation: Linear,
+}
+
+impl LastLayer {
+ fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
+ let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
+ let linear_ = linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
+ let ada_ln_modulation = linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
+ Ok(Self {
+ norm_final,
+ linear: linear_,
+ ada_ln_modulation,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
+ let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
+ let (shift, scale) = (&chunks[0], &chunks[1]);
+ let xs = xs
+ .apply(&self.norm_final)?
+ .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
+ .broadcast_add(&shift.unsqueeze(1)?)?;
+ xs.apply(&self.linear)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Flux {
+ img_in: Linear,
+ txt_in: Linear,
+ time_in: MlpEmbedder,
+ vector_in: MlpEmbedder,
+ guidance_in: Option<MlpEmbedder>,
+ pe_embedder: EmbedNd,
+ double_blocks: Vec<DoubleStreamBlock>,
+ single_blocks: Vec<SingleStreamBlock>,
+ final_layer: LastLayer,
+}
+
+impl Flux {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let img_in = linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
+ let txt_in = linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
+ let mut double_blocks = Vec::with_capacity(cfg.depth);
+ let vb_d = vb.pp("double_blocks");
+ for idx in 0..cfg.depth {
+ let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
+ double_blocks.push(db)
+ }
+ let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
+ let vb_s = vb.pp("single_blocks");
+ for idx in 0..cfg.depth_single_blocks {
+ let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
+ single_blocks.push(sb)
+ }
+ let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
+ let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
+ let guidance_in = if cfg.guidance_embed {
+ let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
+ Some(mlp)
+ } else {
+ None
+ };
+ let final_layer =
+ LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
+ let pe_dim = cfg.hidden_size / cfg.num_heads;
+ let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
+ Ok(Self {
+ img_in,
+ txt_in,
+ time_in,
+ vector_in,
+ guidance_in,
+ pe_embedder,
+ double_blocks,
+ single_blocks,
+ final_layer,
+ })
+ }
+}
+
+impl super::WithForward for Flux {
+ #[allow(clippy::too_many_arguments)]
+ fn forward(
+ &self,
+ img: &Tensor,
+ img_ids: &Tensor,
+ txt: &Tensor,
+ txt_ids: &Tensor,
+ timesteps: &Tensor,
+ y: &Tensor,
+ guidance: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ if txt.rank() != 3 {
+ candle::bail!("unexpected shape for txt {:?}", txt.shape())
+ }
+ if img.rank() != 3 {
+ candle::bail!("unexpected shape for img {:?}", img.shape())
+ }
+ let dtype = img.dtype();
+ let pe = {
+ let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
+ ids.apply(&self.pe_embedder)?
+ };
+ let mut txt = txt.apply(&self.txt_in)?;
+ let mut img = img.apply(&self.img_in)?;
+ let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
+ let vec_ = match (self.guidance_in.as_ref(), guidance) {
+ (Some(g_in), Some(guidance)) => {
+ (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
+ }
+ _ => vec_,
+ };
+ let vec_ = (vec_ + y.apply(&self.vector_in))?;
+
+ // Double blocks
+ for block in self.double_blocks.iter() {
+ (img, txt) = block.forward(&img, &txt, &vec_, &pe)?
+ }
+ // Single blocks
+ let mut img = Tensor::cat(&[&txt, &img], 1)?;
+ for block in self.single_blocks.iter() {
+ img = block.forward(&img, &vec_, &pe)?;
+ }
+ let img = img.i((.., txt.dim(1)?..))?;
+ self.final_layer.forward(&img, &vec_)
+ }
+}
diff --git a/candle-transformers/src/models/flux/sampling.rs b/candle-transformers/src/models/flux/sampling.rs
index 89b9a953..f3f0eafd 100644
--- a/candle-transformers/src/models/flux/sampling.rs
+++ b/candle-transformers/src/models/flux/sampling.rs
@@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
}
#[allow(clippy::too_many_arguments)]
-pub fn denoise(
- model: &super::model::Flux,
+pub fn denoise<M: super::WithForward>(
+ model: &M,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,