summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/utils.rs
blob: 08b78c0402e74d6353c6ed12b721e422b53ff119 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
use candle::{Device, Result, Tensor};

/// Placeholder for a 2D average-pooling helper; not implemented yet.
///
/// NOTE(review): the signature carries no kernel size or stride, so the
/// intended pooling parameters are unknown from here — TODO confirm with callers.
///
/// # Panics
///
/// Always panics via `todo!()`.
pub fn avg_pool2d(_xs: &Tensor) -> Result<Tensor> {
    todo!()
}

/// Placeholder for a 2D nearest-neighbour upsampling helper; not implemented yet.
///
/// NOTE(review): the signature carries no target size or scale factor, so the
/// intended output shape is unknown from here — TODO confirm with callers.
///
/// # Panics
///
/// Always panics via `todo!()`.
pub fn upsample_nearest2d(_xs: &Tensor) -> Result<Tensor> {
    todo!()
}

/// Returns a 1D `f64` tensor of `steps` values evenly spaced over `[start, stop]`,
/// including both endpoints, allocated on the CPU device.
///
/// When `steps == 1` the single value is `start` (matching NumPy's `linspace`),
/// instead of the NaN that dividing by `steps - 1 == 0` would produce.
///
/// # Errors
///
/// Returns an error if `steps` is 0, or if building the tensor fails.
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
    if steps < 1 {
        candle::bail!("cannot use linspace with steps {steps} < 1")
    }
    let vs = if steps == 1 {
        // Dividing by `steps - 1 == 0` would make `delta` inf/NaN, and
        // `0.0 * delta` is NaN, so the lone value must be emitted directly.
        vec![start]
    } else {
        let delta = (stop - start) / (steps - 1) as f64;
        (0..steps)
            .map(|step| start + step as f64 * delta)
            .collect::<Vec<_>>()
    };
    Tensor::from_vec(vs, steps, &Device::Cpu)
}