diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-07 07:53:05 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-07 06:53:05 +0100 |
commit | 5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121 (patch) | |
tree | e099a6cd5b763a36a5113070b497468608b21f01 /candle-examples | |
parent | 2c9f6059760c2c6bb62c6ceac3fd283f52d39fe8 (diff) | |
download | candle-5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121.tar.gz candle-5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121.tar.bz2 candle-5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121.zip |
Implement group-norm. (#334)
* Implement group-norm.
* Add some testing for group-norm.
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/stable-diffusion/clip.rs | 3 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/utils.rs | 5 |
2 files changed, 2 insertions, 6 deletions
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs index be798ad0..227660b1 100644 --- a/candle-examples/examples/stable-diffusion/clip.rs +++ b/candle-examples/examples/stable-diffusion/clip.rs @@ -6,6 +6,7 @@ //! //! https://github.com/openai/CLIP use candle::{Device, Result, Tensor, D}; +use candle_nn as nn; #[derive(Debug, Clone, Copy)] pub enum Activation { @@ -16,7 +17,7 @@ pub enum Activation { impl Activation { fn forward(&self, xs: &Tensor) -> Result<Tensor> { match self { - Activation::QuickGelu => xs * crate::utils::sigmoid(&(xs * 1.702f64)?)?, + Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, Activation::Gelu => xs.gelu(), } } diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 90fe3f9a..4294d823 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -1,10 +1,5 @@ use candle::{Device, Result, Tensor}; -pub fn sigmoid(xs: &Tensor) -> Result<Tensor> { - // TODO: Add sigmoid as binary ops. - (xs.neg()?.exp()? - 1.0)?.recip() -} - pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> { todo!() } |