summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-06 22:14:52 +0200
committerGitHub <noreply@github.com>2023-08-06 21:14:52 +0100
commit166bfd5847144abec227836e497b509625470535 (patch)
tree7b13e3dae76c0864a3cb107c98b3a88f24423af3 /candle-examples
parent1c062bf06ba504a076b329c965c625be0ec67c1d (diff)
downloadcandle-166bfd5847144abec227836e497b509625470535.tar.gz
candle-166bfd5847144abec227836e497b509625470535.tar.bz2
candle-166bfd5847144abec227836e497b509625470535.zip
Add the recip op + use it in stable-diffusion. (#331)
* Add the recip unary op. * Fix the cuda kernel. * Use the recip op in sigmoid.
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs18
1 files changed, 13 insertions, 5 deletions
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
index 50ee48e9..90fe3f9a 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -1,7 +1,8 @@
-use candle::{Result, Tensor};
+use candle::{Device, Result, Tensor};
-pub fn sigmoid(_: &Tensor) -> Result<Tensor> {
- todo!()
+pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
+ // TODO: Add sigmoid as binary ops.
+ (xs.neg()?.exp()? - 1.0)?.recip()
}
pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
@@ -16,6 +17,13 @@ pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
todo!()
}
-pub fn linspace(_: f64, _: f64, _: usize) -> Result<Tensor> {
- todo!()
+pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
+ if steps < 1 {
+ candle::bail!("cannot use linspace with steps {steps} <= 1")
+ }
+ let delta = (stop - start) / (steps - 1) as f64;
+ let vs = (0..steps)
+ .map(|step| start + step as f64 * delta)
+ .collect::<Vec<_>>();
+ Tensor::from_vec(vs, steps, &Device::Cpu)
}