summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-07 07:53:05 +0200
committerGitHub <noreply@github.com>2023-08-07 06:53:05 +0100
commit5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121 (patch)
treee099a6cd5b763a36a5113070b497468608b21f01 /candle-examples
parent2c9f6059760c2c6bb62c6ceac3fd283f52d39fe8 (diff)
downloadcandle-5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121.tar.gz
candle-5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121.tar.bz2
candle-5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121.zip
Implement group-norm. (#334)
* Implement group-norm. * Add some testing for group-norm.
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/stable-diffusion/clip.rs3
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs5
2 files changed, 2 insertions, 6 deletions
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index be798ad0..227660b1 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -6,6 +6,7 @@
//!
//! https://github.com/openai/CLIP
use candle::{Device, Result, Tensor, D};
+use candle_nn as nn;
#[derive(Debug, Clone, Copy)]
pub enum Activation {
@@ -16,7 +17,7 @@ pub enum Activation {
impl Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
- Activation::QuickGelu => xs * crate::utils::sigmoid(&(xs * 1.702f64)?)?,
+ Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
Activation::Gelu => xs.gelu(),
}
}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
index 90fe3f9a..4294d823 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -1,10 +1,5 @@
use candle::{Device, Result, Tensor};
-pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
- // TODO: Add sigmoid as binary ops.
- (xs.neg()?.exp()? - 1.0)?.recip()
-}
-
pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
todo!()
}