author    Laurent Mazare <laurent.mazare@gmail.com>  2024-09-26 10:23:43 +0200
committer GitHub <noreply@github.com>  2024-09-26 10:23:43 +0200
commit    10d47183c088ce449da13d74f07171c8106cd6dd (patch)
tree      b91b0398fcb314e998b9f7f3b23877f63462b232 /candle-transformers/src/models/flux/model.rs
parent    d01207dbf3fb0ad614e7915c8f5706fbc09902fb (diff)
Quantized version of flux. (#2500)
* Quantized version of flux.
* More generic sampling.
* Hook the quantized model.
* Use the newly minted gguf file.
* Fix for the quantized model.
* Default to avoid the faster cuda kernels.
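The final hunk of the diff below moves `Flux::forward` behind a `super::WithForward` trait so the full-precision and GGUF-quantized models can be driven by the same sampling loop. The trait definition itself lives outside this file and is not part of this diff; the sketch below is a reconstruction, and every parameter after `img_ids` (where the hunk is truncated) is an assumption based on the usual Flux forward signature.

    use candle::{Result, Tensor};

    /// Hedged sketch of the trait implemented in the last hunk. The real
    /// definition is in flux/mod.rs and is not shown in this diff; the
    /// parameters after `img_ids` are assumptions, since the hunk below
    /// is cut off at that point.
    pub trait WithForward {
        #[allow(clippy::too_many_arguments)]
        fn forward(
            &self,
            img: &Tensor,
            img_ids: &Tensor,
            txt: &Tensor,
            txt_ids: &Tensor,
            timesteps: &Tensor,
            y: &Tensor,
            guidance: Option<&Tensor>,
        ) -> Result<Tensor>;
    }

With the model behind a trait object (e.g. `Box<dyn WithForward>`), the sampler from "More generic sampling" can dispatch to either the float or the quantized backend at runtime without changing its loop.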
Diffstat (limited to 'candle-transformers/src/models/flux/model.rs')
-rw-r--r--  candle-transformers/src/models/flux/model.rs | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
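The first two hunks below widen `attention` and `timestep_embedding` from private to `pub(crate)` so that the quantized model can reuse them. A minimal sketch of that reuse from a sibling module inside the crate; the module placement and the combining function here are purely illustrative, not the commit's actual code:

    // Assumed to live in a sibling module such as flux/quantized_model.rs;
    // once the helpers are pub(crate), any module in the crate can call them.
    use candle::{DType, Result, Tensor};
    use super::model::{attention, timestep_embedding};

    // Illustrative only: computes rotary-embedded attention plus a 256-dim
    // sinusoidal embedding of the timestep, as the hunks below expose them.
    fn attn_with_time_cond(
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        pe: &Tensor,
        t: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let attn = attention(q, k, v, pe)?;
        let t_emb = timestep_embedding(t, 256, DType::F32)?;
        Ok((attn, t_emb))
    }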
diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs
index 4e47873f..17b4eb25 100644
--- a/candle-transformers/src/models/flux/model.rs
+++ b/candle-transformers/src/models/flux/model.rs
@@ -109,14 +109,14 @@ fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}
-fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
+pub(crate) fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
let q = apply_rope(q, pe)?.contiguous()?;
let k = apply_rope(k, pe)?.contiguous()?;
let x = scaled_dot_product_attention(&q, &k, v)?;
x.transpose(1, 2)?.flatten_from(2)
}
-fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
+pub(crate) fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
@@ -144,7 +144,7 @@ pub struct EmbedNd {
}
impl EmbedNd {
- fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
+ pub fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
Self {
dim,
theta,
@@ -575,9 +575,11 @@ impl Flux {
final_layer,
})
}
+}
+impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
- pub fn forward(
+ fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
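For reference, the `timestep_embedding` made `pub(crate)` above is a standard sinusoidal embedding: timesteps are scaled by TIME_FACTOR, multiplied against a geometric ladder of MAX_PERIOD-based frequencies, and the cosines and sines of the products are concatenated. A hedged re-derivation under the constants visible in the hunk; the body is reconstructed, not the crate's actual implementation:

    use candle::{DType, Result, Tensor};

    // Reconstructed sketch matching the signature and constants shown in the
    // second hunk above; an assumption, not the crate's real implementation.
    fn timestep_embedding_sketch(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
        const TIME_FACTOR: f64 = 1000.;
        const MAX_PERIOD: f64 = 10000.;
        if dim % 2 == 1 {
            candle::bail!("embedding dim {dim} must be even")
        }
        let half = dim / 2;
        // Geometric frequency ladder: freq_i = MAX_PERIOD^(-i / half).
        let freqs: Vec<f32> = (0..half)
            .map(|i| (-(i as f64) * MAX_PERIOD.ln() / half as f64).exp() as f32)
            .collect();
        let freqs = Tensor::new(freqs.as_slice(), t.device())?;
        // Outer product of scaled timesteps with the frequencies: (batch, half).
        let args = (t.to_dtype(DType::F32)? * TIME_FACTOR)?
            .unsqueeze(1)?
            .broadcast_mul(&freqs.unsqueeze(0)?)?;
        // cos/sin halves concatenated along the feature dim: (batch, dim).
        Tensor::cat(&[args.cos()?, args.sin()?], 1)?.to_dtype(dtype)
    }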