summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-08 15:57:09 +0200
committerGitHub <noreply@github.com>2023-08-08 14:57:09 +0100
commit89d3926c9b0f497b48624f2719df6091e5d8785c (patch)
tree2ea4a853372ac48b1e6660dd5872958c6c05d54d /candle-nn/src
parentab3568432608316b89791eaa4085a5cb519fe6c3 (diff)
downloadcandle-89d3926c9b0f497b48624f2719df6091e5d8785c.tar.gz
candle-89d3926c9b0f497b48624f2719df6091e5d8785c.tar.bz2
candle-89d3926c9b0f497b48624f2719df6091e5d8785c.zip
Fixes for the stable diffusion example. (#342)
* Fixes for the stable diffusion example. * Bugfix. * Another fix. * Fix for group-norm. * More fixes to get SD to work.
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/group_norm.rs12
1 files changed, 8 insertions, 4 deletions
diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs
index e277ae85..ac77db4b 100644
--- a/candle-nn/src/group_norm.rs
+++ b/candle-nn/src/group_norm.rs
@@ -59,17 +59,21 @@ impl GroupNorm {
let x = x.broadcast_sub(&mean_x)?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+ let mut w_dims = vec![1; x_shape.len()];
+ w_dims[1] = n_channels;
+ let weight = self.weight.reshape(w_dims.clone())?;
+ let bias = self.bias.reshape(w_dims)?;
x_normed
.to_dtype(x_dtype)?
- .broadcast_mul(&self.weight)?
- .broadcast_add(&self.bias)?
- .reshape(x_shape)
+ .reshape(x_shape)?
+ .broadcast_mul(&weight)?
+ .broadcast_add(&bias)
}
}
pub fn group_norm(
- num_channels: usize,
num_groups: usize,
+ num_channels: usize,
eps: f64,
vb: crate::VarBuilder,
) -> Result<GroupNorm> {