diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-06 18:49:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-06 17:49:43 +0100 |
commit | d34039e35267b3f4de83770f8da4ea31491bcec5 (patch) | |
tree | 6efc4796859a04223d2211303cc12b3a97321fcc /candle-examples/examples/stable-diffusion/embeddings.rs | |
parent | 93cfe5642f473889d1df62ccb8f1740f77523dd3 (diff) | |
download | candle-d34039e35267b3f4de83770f8da4ea31491bcec5.tar.gz candle-d34039e35267b3f4de83770f8da4ea31491bcec5.tar.bz2 candle-d34039e35267b3f4de83770f8da4ea31491bcec5.zip |
Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.
* Proper computation of the causal mask.
* Add the chunk operation.
* Work in progress: port the attention module.
* Add some dummy modules for conv2d and group-norm, get the attention module to compile.
* Re-enable the 2d convolution.
* Add the embeddings module.
* Add the resnet module.
* Add the unet blocks.
* Add the unet.
* And add the variational auto-encoder.
* Use the pad function from utils.
Diffstat (limited to 'candle-examples/examples/stable-diffusion/embeddings.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/embeddings.rs | 65 |
1 file changed, 65 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs new file mode 100644 index 00000000..f8a4f351 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/embeddings.rs @@ -0,0 +1,65 @@ +#![allow(dead_code)] +use candle::{Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug)] +pub struct TimestepEmbedding { + linear_1: nn::Linear, + linear_2: nn::Linear, +} + +impl TimestepEmbedding { + // act_fn: "silu" + pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> { + let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?; + let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?; + Ok(Self { linear_1, linear_2 }) + } +} + +impl TimestepEmbedding { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?; + self.linear_2.forward(&xs) + } +} + +#[derive(Debug)] +pub struct Timesteps { + num_channels: usize, + flip_sin_to_cos: bool, + downscale_freq_shift: f64, +} + +impl Timesteps { + pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self { + Self { + num_channels, + flip_sin_to_cos, + downscale_freq_shift, + } + } +} + +impl Timesteps { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let half_dim = (self.num_channels / 2) as u32; + let exponent = + (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?; + let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?; + let emb = exponent.exp()?; + // emb = timesteps[:, None].float() * emb[None, :] + let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?; + let (cos, sin) = (emb.cos()?, emb.sin()?); + let emb = if self.flip_sin_to_cos { + Tensor::cat(&[&cos, &sin], D::Minus1)? + } else { + Tensor::cat(&[&sin, &cos], D::Minus1)? 
+ }; + if self.num_channels % 2 == 1 { + crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None) + } else { + Ok(emb) + } + } +} |