author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-26 10:23:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-26 10:23:43 +0200 |
commit | 10d47183c088ce449da13d74f07171c8106cd6dd | |
tree | b91b0398fcb314e998b9f7f3b23877f63462b232 /candle-transformers/src/models/flux/model.rs | |
parent | d01207dbf3fb0ad614e7915c8f5706fbc09902fb | |
Quantized version of flux. (#2500)
* Quantized version of flux.
* More generic sampling (see the trait-generic loop sketch after this list).
* Hook the quantized model.
* Use the newly minted gguf file.
* Fix for the quantized model.
* Default to avoiding the faster CUDA kernels.
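"More generic sampling" hints that the denoising loop is now written against a trait rather than the concrete `Flux` struct, so the quantized model hooked up in this PR can drive the same loop. The following is a minimal sketch of that pattern, assuming a candle-style API; the names `DenoiseModel` and `sample` and the pared-down two-tensor signature are illustrative, not the actual candle interface.

```rust
use candle_core::{Result, Tensor};

// Hypothetical trait for illustration: any model exposing a forward pass
// over the current latents can drive the sampling loop.
trait DenoiseModel {
    fn forward(&self, img: &Tensor, timesteps: &Tensor) -> Result<Tensor>;
}

// Euler-style rectified-flow integration, generic over the model so a
// full-precision and a quantized implementation plug in interchangeably.
fn sample<M: DenoiseModel>(model: &M, mut img: Tensor, sigmas: &[f64]) -> Result<Tensor> {
    let batch = img.dims()[0];
    for window in sigmas.windows(2) {
        let (s_curr, s_next) = (window[0], window[1]);
        // Broadcast the current noise level to one timestep per batch element.
        let timesteps = Tensor::full(s_curr as f32, batch, img.device())?;
        let pred = model.forward(&img, &timesteps)?;
        // x_next = x + (sigma_next - sigma_curr) * v_pred
        img = (img + pred.affine(s_next - s_curr, 0.)?)?;
    }
    Ok(img)
}
```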
Diffstat (limited to 'candle-transformers/src/models/flux/model.rs')
-rw-r--r-- | candle-transformers/src/models/flux/model.rs | 10 |
1 file changed, 6 insertions(+), 4 deletions(-)
```diff
diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs
index 4e47873f..17b4eb25 100644
--- a/candle-transformers/src/models/flux/model.rs
+++ b/candle-transformers/src/models/flux/model.rs
@@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
     (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
 }
 
-fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
+pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
     let q = apply_rope(q, pe)?.contiguous()?;
     let k = apply_rope(k, pe)?.contiguous()?;
     let x = scaled_dot_product_attention(&q, &k, v)?;
     x.transpose(1, 2)?.flatten_from(2)
 }
 
-fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
+pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
     const TIME_FACTOR: f64 = 1000.;
     const MAX_PERIOD: f64 = 10000.;
     if dim % 2 == 1 {
@@ -144,7 +144,7 @@ pub struct EmbedNd {
 }
 
 impl EmbedNd {
-    fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
+    pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
         Self {
             dim,
             theta,
@@ -575,9 +575,11 @@ impl Flux {
             final_layer,
         })
     }
+}
 
+impl super::WithForward for Flux {
     #[allow(clippy::too_many_arguments)]
-    pub fn forward(
+    fn forward(
         &self,
         img: &Tensor,
         img_ids: &Tensor,
```
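The last hunk moves `forward` out of the inherent `impl Flux` block and into an implementation of `super::WithForward`, a trait defined in the parent `flux` module and not shown in this diff. A minimal sketch of what that trait plausibly looks like follows; the hunk cuts off after `img_ids`, so every parameter past that point is an assumption, flagged in the comments.

```rust
use candle_core::{Result, Tensor};

// Sketch of the trait this diff implements; the real definition lives in
// candle-transformers/src/models/flux/mod.rs and may differ.
pub trait WithForward {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,              // assumed: text token stream
        txt_ids: &Tensor,          // assumed: positional ids for the text stream
        timesteps: &Tensor,        // assumed: diffusion timesteps
        y: &Tensor,                // assumed: pooled conditioning vector
        guidance: Option<&Tensor>, // assumed: distilled-guidance scale
    ) -> Result<Tensor>;
}
```

With both the full-precision model and the newly hooked quantized model implementing this trait, a caller only needs an `M: WithForward` bound to accept either one.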