summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/wuerstchen/main.rs1
-rw-r--r--candle-transformers/src/models/stable_diffusion/clip.rs6
-rw-r--r--candle-transformers/src/models/wuerstchen/common.rs4
-rw-r--r--candle-transformers/src/models/wuerstchen/ddpm.rs6
4 files changed, 10 insertions, 7 deletions
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index 8064f87f..bce68114 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -373,7 +373,6 @@ fn run(args: Args) -> Result<()> {
);
let image = vqgan.decode(&(&latents * 0.3764)?)?;
// TODO: Add the clamping between 0 and 1.
- let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 7f86cf31..e7a20270 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -12,6 +12,7 @@ use candle_nn::Module;
pub enum Activation {
QuickGelu,
Gelu,
+ GeluErf,
}
impl Module for Activation {
@@ -19,6 +20,7 @@ impl Module for Activation {
match self {
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
Activation::Gelu => xs.gelu(),
+ Activation::GeluErf => xs.gelu_erf(),
}
}
}
@@ -111,7 +113,7 @@ impl Config {
num_hidden_layers: 24,
num_attention_heads: 16,
projection_dim: 1024,
- activation: Activation::Gelu,
+ activation: Activation::GeluErf,
}
}
@@ -126,7 +128,7 @@ impl Config {
num_hidden_layers: 32,
num_attention_heads: 20,
projection_dim: 512,
- activation: Activation::Gelu,
+ activation: Activation::GeluErf,
}
}
}
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 3cac2a59..8416a1f1 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -100,7 +100,7 @@ impl GlobalResponseNorm {
impl Module for GlobalResponseNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
+ let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let stand_div_norm =
agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
xs.broadcast_mul(&stand_div_norm)?
@@ -152,7 +152,7 @@ impl ResBlock {
.permute((0, 2, 3, 1))?;
let xs = xs
.apply(&self.channelwise_lin1)?
- .gelu()?
+ .gelu_erf()?
.apply(&self.channelwise_grn)?
.apply(&self.channelwise_lin2)?
.permute((0, 3, 1, 2))?;
diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs
index 80640072..9e69b868 100644
--- a/candle-transformers/src/models/wuerstchen/ddpm.rs
+++ b/candle-transformers/src/models/wuerstchen/ddpm.rs
@@ -52,8 +52,10 @@ impl DDPMWScheduler {
} else {
t
};
- let alpha_cumprod =
- ((t + s) / (1. + s) * std::f64::consts::PI * 0.5).powi(2) / self.init_alpha_cumprod;
+ let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
+ .cos()
+ .powi(2)
+ / self.init_alpha_cumprod;
alpha_cumprod.clamp(0.0001, 0.9999)
}