-rw-r--r--  .gitignore                                                       3
-rw-r--r--  candle-core/src/backprop.rs                                     62
-rw-r--r--  candle-core/src/conv.rs                                         12
-rw-r--r--  candle-core/src/cpu_backend.rs                                  87
-rw-r--r--  candle-core/src/cuda_backend.rs                                 66
-rw-r--r--  candle-core/src/display.rs                                       2
-rw-r--r--  candle-core/src/quantized/avx.rs                               126
-rw-r--r--  candle-core/src/quantized/k_quants.rs                            3
-rw-r--r--  candle-core/tests/conv_tests.rs                                138
-rw-r--r--  candle-examples/examples/stable-diffusion/attention.rs           8
-rw-r--r--  candle-examples/examples/stable-diffusion/clip.rs               30
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs              170
-rw-r--r--  candle-examples/examples/stable-diffusion/stable_diffusion.rs  128
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d.rs            25
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d_blocks.rs     12
-rw-r--r--  candle-kernels/src/conv.cu                                      88
-rw-r--r--  candle-nn/src/var_builder.rs                                    44
17 files changed, 851 insertions, 153 deletions
diff --git a/.gitignore b/.gitignore
index 85dc61c0..2748d37e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,4 +32,5 @@ candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json
-.DS_Store
\ No newline at end of file
+.DS_Store
+.idea/*
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 22c28ac4..9ecdee4f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -192,12 +192,68 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
- Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
+ Op::Conv2D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ } => {
+ // The output height for conv_transpose2d is:
+ // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
+ let grad_h = grad.dim(2)?;
+ let k_h = kernel.dim(2)?;
+ let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+ let out_padding = arg.dim(2)? - out_size;
+ let grad_arg =
+ grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+
+ let grad_kernel = arg
+ .transpose(0, 1)?
+ .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+ .transpose(0, 1)?;
+ let sum_grad = grads.or_insert(kernel)?;
+ *sum_grad = sum_grad.add(&grad_kernel)?;
+ }
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
- Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
- Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
+ Op::AvgPool2D {
+ arg,
+ kernel_size,
+ stride,
+ } => {
+ if kernel_size != stride {
+ crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
+ }
+ let (_n, _c, h, w) = arg.dims4()?;
+ let grad_arg = grad.upsample_nearest2d(h, w)?;
+ let grad_arg =
+ (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+ }
+ Op::MaxPool2D {
+ arg,
+ kernel_size,
+ stride,
+ } => {
+ if kernel_size != stride {
+ crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
+ }
+ let (_n, _c, h, w) = arg.dims4()?;
+                    // For computing the max-pool gradient, we compute a mask where a 1 means
+                    // that the element is the maximum, then we apply this mask to the
+                    // upsampled gradient (taking into account that multiple maxima may exist,
+                    // in which case we scale the gradient accordingly).
+ let node_upsampled = node.upsample_nearest2d(h, w)?;
+ let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
+ let avg = mask.avg_pool2d(*kernel_size, *stride)?;
+ let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+ }
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
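
The new Conv2D backward above recovers the input gradient with a conv_transpose2d whose output_padding absorbs any truncation introduced by the forward stride, and the kernel gradient with a conv2d applied to the batch/channel transposed input and output gradient. A minimal standalone sketch of the size arithmetic, using assumed dimensions that match the conv2d_grad test added later in this patch (5x5 input, 3x3 kernel, stride 1, padding 0):

fn main() {
    let (i_h, k_h, stride, padding) = (5usize, 3usize, 1usize, 0usize);
    // Forward conv2d output height: (i_h + 2 * padding - k_h) / stride + 1.
    let grad_h = (i_h + 2 * padding - k_h) / stride + 1; // 3
    // conv_transpose2d output height with dilation = 1 and out_padding = 0:
    // (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding.
    let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding; // 5
    // Whatever is still missing to get back to the input height becomes out_padding
    // (non-zero when the forward stride truncated the output size).
    let out_padding = i_h - out_size; // 0 here
    assert_eq!(out_size + out_padding, i_h);
    println!("grad_h={grad_h} out_size={out_size} out_padding={out_padding}");
}
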
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 3455247b..d9e0a9ab 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -71,18 +71,14 @@ pub struct ParamsConvTranspose2D {
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
let dilation = 1;
- (self.i_h - 1) * self.stride - 2 * self.padding
- + dilation * (self.k_h - 1)
- + self.output_padding
- + 1
+ (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+ - 2 * self.padding
}
pub(crate) fn out_w(&self) -> usize {
let dilation = 1;
- (self.i_w - 1) * self.stride - 2 * self.padding
- + dilation * (self.k_w - 1)
- + self.output_padding
- + 1
+ (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+ - 2 * self.padding
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
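
The conv.rs change is purely an order-of-operations fix: the out_h/out_w formula is unchanged, but summing all the positive terms before subtracting 2 * padding avoids an intermediate usize underflow when the padding term is larger than the partial sum. A small standalone sketch (the numbers are illustrative assumptions, not taken from the patch):

fn main() {
    let (i_h, stride, padding, k_h, output_padding) = (2usize, 1usize, 1usize, 3usize, 0usize);
    let dilation = 1usize;
    // Old ordering: (i_h - 1) * stride - 2 * padding evaluates 1 - 2 first and underflows,
    // even though the complete expression is non-negative.
    assert!(((i_h - 1) * stride).checked_sub(2 * padding).is_none());
    // New ordering: add everything first, subtract the padding term last.
    let out_h = (i_h - 1) * stride + dilation * (k_h - 1) + output_padding + 1 - 2 * padding;
    assert_eq!(out_h, 2);
}
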
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0b19904b..f52d53b1 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1193,41 +1193,78 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
let (out_h, out_w) = (p.out_h(), p.out_w());
// Output shape: [b_size, c_out, out_h, out_w].
- let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+ let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
let dst_s0 = p.c_out * out_h * out_w;
let dst_s1 = out_h * out_w;
let dst_s2 = out_w;
let dst_s3 = 1;
+
+ // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
+ let cont_s0 = p.i_h * p.i_w * p.c_in;
+ let cont_s1 = p.i_w * p.c_in;
+ let cont_s2 = p.c_in;
for b_idx in 0..p.b_size {
- for out_y in 0..out_h as i32 {
- for out_x in 0..out_w as i32 {
- let inp_x = out_x * p.stride as i32 - p.padding as i32;
- let inp_y = out_y * p.stride as i32 - p.padding as i32;
- for k_y in 0..p.k_h as i32 {
- for k_x in 0..p.k_h as i32 {
- let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
- let inp_y = inp_y + k_y;
- let inp_x = inp_x + k_x;
- if inp_x < 0 || inp_y < 0 {
- continue;
- }
- let inp_x = inp_x as usize;
- let inp_y = inp_y as usize;
- if inp_x < p.i_w && inp_y < p.i_h {
- let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
- let dst_index = b_idx * dst_s0 + inp_y * dst_s2 + inp_x * dst_s3;
- for c_out in 0..k_s0 {
- for c_in in 0..k_s1 {
- let k_index = k_index + c_out * k_s1 + c_in * k_s0;
- let dst_index = dst_index + c_out * dst_s1;
- let inp_index = inp_index + c_in * inp_s1;
- dst[dst_index] += k[k_index] * inp[inp_index]
+ for h_idx in 0..p.i_h {
+ for w_idx in 0..p.i_w {
+ for c_idx in 0..p.c_in {
+ let src_idx =
+ b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
+ let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
+ inp_cont[dst_idx] = inp[src_idx]
+ }
+ }
+ }
+ }
+ let num_threads = crate::utils::get_num_threads();
+
+ for k_y in 0..p.k_h {
+ for k_x in 0..p.k_w {
+ crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+ let k_cont = (0..p.c_in)
+ .map(|c_in_idx| {
+ k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
+ })
+ .collect::<Vec<_>>();
+ for b_idx in 0..p.b_size {
+ for inp_y in 0..p.i_h {
+ for inp_x in 0..p.i_w {
+ let out_x = inp_x * p.stride + k_x;
+ let out_y = inp_y * p.stride + k_y;
+ if out_x < p.padding || out_y < p.padding {
+ continue;
+ }
+ let out_x = out_x - p.padding;
+ let out_y = out_y - p.padding;
+ if out_x < out_w && out_y < out_h {
+ let inp_cont = &inp_cont
+ [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
+ let dst_idx = b_idx * dst_s0
+ + out_y * dst_s2
+ + out_x * dst_s3
+ + dst_c_idx * dst_s1;
+ let mut d = T::zero();
+ unsafe {
+ T::vec_dot(
+ inp_cont.as_ptr(),
+ k_cont.as_ptr(),
+ &mut d,
+ p.c_in,
+ )
+ }
+ let dst_p = dst.as_ptr();
+                            // Safety: each dst_idx is unique per dst_c_idx, which is used to
+                            // parallelise the different tasks, so no two threads can try to
+                            // write at the same location.
+ unsafe {
+ let ptr = dst_p.add(dst_idx) as *mut T;
+ *ptr += d
}
}
}
}
}
- }
+ })
}
}
Ok(dst)
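
The rewritten CPU conv_transpose2d above first copies the input into a channels-last buffer so that, for a fixed kernel tap (k_y, k_x) and output channel, the reduction over input channels is a single contiguous vec_dot, and it parallelises over output channels so that no two threads ever write the same destination index. The operation it implements is the usual scatter form of a transposed convolution; here is a hedged scalar reference for a single batch element (plain slices and assumed NCHW layouts for illustration, not the candle API):

// Scatter-form transposed convolution over flat f32 buffers, single batch element.
fn conv_transpose2d_ref(
    inp: &[f32], (c_in, i_h, i_w): (usize, usize, usize),
    k: &[f32], (_c_in_k, c_out, k_h, k_w): (usize, usize, usize, usize),
    stride: usize, padding: usize,
) -> (Vec<f32>, usize, usize) {
    let out_h = (i_h - 1) * stride + k_h - 2 * padding;
    let out_w = (i_w - 1) * stride + k_w - 2 * padding;
    let mut dst = vec![0f32; c_out * out_h * out_w];
    for c_i in 0..c_in {
        for y in 0..i_h {
            for x in 0..i_w {
                let v = inp[c_i * i_h * i_w + y * i_w + x];
                for c_o in 0..c_out {
                    for ky in 0..k_h {
                        for kx in 0..k_w {
                            // Each input sample scatters into out = inp * stride + k - padding.
                            let (oy, ox) = (y * stride + ky, x * stride + kx);
                            if oy < padding || ox < padding {
                                continue;
                            }
                            let (oy, ox) = (oy - padding, ox - padding);
                            if oy >= out_h || ox >= out_w {
                                continue;
                            }
                            // Kernel layout: (c_in, c_out, k_h, k_w).
                            let kv = k[c_i * c_out * k_h * k_w + c_o * k_h * k_w + ky * k_w + kx];
                            dst[c_o * out_h * out_w + oy * out_w + ox] += v * kv;
                        }
                    }
                }
            }
        }
    }
    (dst, out_h, out_w)
}

fn main() {
    // 1x1 input, one channel, 2x2 kernel of ones, stride 1, padding 0 -> 2x2 output of ones.
    let (dst, out_h, out_w) =
        conv_transpose2d_ref(&[1.0], (1, 1, 1), &[1.0; 4], (1, 1, 2, 2), 1, 0);
    assert_eq!((out_h, out_w), (2, 2));
    assert_eq!(dst, vec![1.0; 4]);
}
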
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 75eaf70a..ed696368 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -977,8 +977,8 @@ impl<'a> Map2 for Conv2D<'a> {
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
- // Kernel shape: (c_out, c_in_k, w_k, h_k)
- // Input shape: (b_size, c_in, w_in, c_in)
+ // Kernel shape: (c_out, c_in_k, h_k, w_k)
+ // Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
@@ -1005,6 +1005,55 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
+struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
+impl<'a> Map2 for ConvTranspose2D<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ inp: &CudaSlice<T>,
+ inp_l: &Layout,
+ k: &CudaSlice<T>,
+ k_l: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ // Kernel shape: (c_in_k, c_out, h_k, w_k)
+ // Input shape: (b_size, c_in, h_in, w_in)
+ let p = &self.0;
+ let (out_w, out_h) = (p.out_w(), p.out_h());
+ let dst_el = p.c_out * out_w * out_h * p.b_size;
+ let inp = &inp.slice(inp_l.start_offset()..);
+ let k = &k.slice(k_l.start_offset()..);
+ let shape = inp_l.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
+ let ds = if dims.len() == 4 {
+ [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
+ } else {
+ crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
+ };
+ let ds = dev.htod_copy(ds).w()?;
+ let params = (
+ el,
+ out_w,
+ out_h,
+ p.stride,
+ p.padding,
+ p.output_padding,
+ &ds,
+ inp,
+ k,
+ &out,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(out)
+ }
+}
+
enum PoolOp {
Max,
Avg,
@@ -1649,12 +1698,15 @@ impl BackendStorage for CudaStorage {
fn conv_transpose2d(
&self,
- _l: &Layout,
- _kernel: &Self,
- _kernel_l: &Layout,
- _params: &crate::conv::ParamsConvTranspose2D,
+ l: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
- todo!()
+ let device = self.device().clone();
+ let slice =
+ ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+ Ok(Self { slice, device })
}
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
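
The CUDA path packs the shape information for the new conv_transpose2d kernel into a single device buffer: four input dims, four input strides, four kernel dims and four kernel strides, concatenated in that order, which is why the kernel in conv.cu reads them at info offsets 0, 4, 8 and 12. A host-side sketch of that layout with assumed shapes (the shapes are illustrative, the ordering is the one used above):

fn main() {
    let inp_dims = [1usize, 4, 5, 5]; // (b_size, c_in, h_in, w_in)
    let inp_stride = [100usize, 25, 5, 1]; // contiguous NCHW strides
    let k_dims = [4usize, 2, 3, 3]; // (c_in_k, c_out, h_k, w_k)
    let k_stride = [18usize, 9, 3, 1];
    let ds: Vec<usize> = [inp_dims, inp_stride, k_dims, k_stride].concat();
    assert_eq!(ds.len(), 16);
    assert_eq!(&ds[8..12], &k_dims[..]); // kernel dims start at offset 8
}
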
diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index 8390a4a0..b497699b 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -43,7 +43,7 @@ impl Tensor {
}
}
}
- write!(f, "; {} ,{}]", self.dtype().as_str(), device_str)
+ write!(f, "; {}{}]", self.dtype().as_str(), device_str)
}
}
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 96087feb..f906d090 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -1,5 +1,6 @@
-use super::k_quants::{BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
+use byteorder::{ByteOrder, LittleEndian};
use half::f16;
#[cfg(target_arch = "x86")]
@@ -89,17 +90,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
}
}
-const K_SHUFFLE: [u8; 128] = [
- 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
- 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
- 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11,
- 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14,
- 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
-];
-
+#[inline(always)]
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
+ const K_SHUFFLE: [u8; 128] = [
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
+ 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
+ 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
+ 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
+ 13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
+ ];
_mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
}
+
+#[inline(always)]
+unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
+ const K_SHUFFLE: [u8; 256] = [
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+ 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+ 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
+ 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
+ 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
+ 13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ ];
+ _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
+}
+
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
@@ -187,3 +206,92 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
Ok(hsum_float_8(acc))
}
}
+
+#[inline(always)]
+unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
+ _mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
+}
+
+#[inline(always)]
+pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+ }
+ let mut utmp = [0u32; 4];
+ let kmask1: u32 = 0x3f3f3f3f;
+ let kmask2: u32 = 0x0f0f0f0f;
+ let kmask3: u32 = 0x03030303;
+
+ unsafe {
+ let m4 = _mm256_set1_epi8(0xF);
+
+ let mut acc = _mm256_setzero_ps();
+ let mut acc_m = _mm_setzero_ps();
+
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let d = y.d * x.d.to_f32();
+ let dmin = -y.d * x.dmin.to_f32();
+
+ LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ let uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ let mut q4 = x.qs.as_ptr();
+ let mut q8 = y.qs.as_ptr();
+
+ let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
+ utmp[3] as i32,
+ utmp[2] as i32,
+ utmp[1] as i32,
+ utmp[0] as i32,
+ ));
+
+ let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
+ let q8s = _mm_hadd_epi16(
+ _mm256_extracti128_si256(q8sums, 0),
+ _mm256_extracti128_si256(q8sums, 1),
+ );
+ let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+ acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
+
+ let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ let scales = mm256_set_m128i(sc128, sc128);
+
+ let mut sumi = _mm256_setzero_si256();
+
+ for j in 0..QK_K / 64 {
+ let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
+ let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
+
+ let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
+ q4 = q4.add(32);
+ let q4l = _mm256_and_si256(q4bits, m4);
+ let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+ let q8l = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let p16l = _mm256_maddubs_epi16(q4l, q8l);
+ let p16l = _mm256_madd_epi16(scale_l, p16l);
+ sumi = _mm256_add_epi32(sumi, p16l);
+
+ let q8h = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let p16h = _mm256_maddubs_epi16(q4h, q8h);
+ let p16h = _mm256_madd_epi16(scale_h, p16h);
+ sumi = _mm256_add_epi32(sumi, p16h);
+ }
+
+ let vd = _mm256_set1_ps(d);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+ }
+
+ let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+ let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+ Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
+ }
+}
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 7b405ec9..7f14600b 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1104,6 +1104,9 @@ impl GgmlType for BlockQ4K {
#[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+ #[cfg(target_feature = "avx")]
+ return super::avx::vec_dot_q4k_q8k(n, xs, ys);
+
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4k_q8k(n, xs, ys);
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 310d2462..1c378e5e 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -1,5 +1,5 @@
use anyhow::Result;
-use candle_core::{test_device, test_utils, Device, Tensor};
+use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};
/* This test is based on the following script.
import torch
@@ -76,6 +76,11 @@ print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv2d(t, w)
print(res.flatten())
+
+w_t = w.transpose(0, 1)
+res = torch.nn.functional.conv_transpose2d(t, w_t)
+print(res.shape)
+print(res)
*/
fn conv2d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@@ -117,6 +122,31 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+ assert_eq!(res.dims(), [1, 2, 7, 7]);
+ assert_eq!(
+ test_utils::to_vec3_round(&res.i(0)?, 4)?,
+ [
+ [
+ [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
+ [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
+ [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
+ [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
+ [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
+ [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
+ [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
+ ],
+ [
+ [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
+ [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
+ [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
+ [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
+ [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
+ [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
+ [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
+ ]
+ ]
+ );
Ok(())
}
@@ -170,26 +200,23 @@ fn conv2d_small(dev: &Device) -> Result<()> {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
- // TODO: enable the test for cuda once we have the proper implementation in place.
- if dev.is_cpu() {
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
- assert_eq!(res.dims(), [1, 1, 3, 3]);
- assert_eq!(
- test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
- [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
- );
- let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
- assert_eq!(res.dims(), [2, 2, 3, 3]);
- assert_eq!(
- test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
- [
- -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728,
- 0.528, -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838,
- 0.5802, -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396,
- -0.8156, 0.4594, 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
- ]
- );
- }
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+ assert_eq!(res.dims(), [1, 1, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
+ );
+ let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
+ assert_eq!(res.dims(), [2, 2, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [
+ -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
+ -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
+ -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
+ 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
+ ]
+ );
Ok(())
}
@@ -243,6 +270,74 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
Ok(())
}
+fn conv2d_grad(dev: &Device) -> Result<()> {
+ use candle_core::Var;
+ let t = Var::from_slice(
+ &[
+ 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
+ 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
+ 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
+ 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
+ 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
+ 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
+ 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
+ 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
+ -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
+ -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
+ ],
+ (1, 4, 5, 5),
+ dev,
+ )?;
+ let w = Var::from_slice(
+ &[
+ -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
+ -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
+ -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
+ 0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
+ 0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
+ -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
+ 1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
+ 0.5583, 0.4623, 0.6026,
+ ],
+ (2, 4, 3, 3),
+ dev,
+ )?;
+ let res = t.conv2d(&w, 0, 1, 1)?;
+ let loss = res.sqr()?.sum_all()?;
+ assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
+ let grads = loss.backward()?;
+ let grad_t = grads.get(&t).unwrap();
+ let grad_w = grads.get(&w).unwrap();
+ assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+ assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
+ [
+ 9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
+ -39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
+ -20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
+ -75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
+ 49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
+ -10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
+ 20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
+ 25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
+ -28.57, -9.13, 7.21, -9.05, -9.62, -11.25
+ ]
+ );
+ assert_eq!(
+ test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
+ [
+ -28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
+ 28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
+ 22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
+ 58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
+ 47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
+ -34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
+ ]
+ );
+ Ok(())
+}
+
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
@@ -253,3 +348,4 @@ test_device!(
);
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
+test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index 58f5e87e..1ae1bfc3 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -208,9 +208,9 @@ impl CrossAttention {
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let query = self.to_q.forward(xs)?;
- let context = context.unwrap_or(xs);
- let key = self.to_k.forward(context)?;
- let value = self.to_v.forward(context)?;
+ let context = context.unwrap_or(xs).contiguous()?;
+ let key = self.to_k.forward(&context)?;
+ let value = self.to_v.forward(&context)?;
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
@@ -473,7 +473,7 @@ impl AttentionBlock {
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
- let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
+ let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
("to_q", "to_k", "to_v", "to_out.0")
} else {
("query", "key", "value", "proj_attn")
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 2927a404..d26c1c46 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -69,6 +69,36 @@ impl Config {
activation: Activation::Gelu,
}
}
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
+ pub fn sdxl() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
+ pub fn sdxl2() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1280,
+ intermediate_size: 5120,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 32,
+ num_attention_heads: 20,
+ projection_dim: 1280,
+ activation: Activation::Gelu,
+ }
+ }
}
// CLIP Text Model
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 1443986c..8372edcd 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -17,7 +17,7 @@ mod utils;
mod vae;
use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{DType, Device, IndexOp, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
@@ -102,12 +102,16 @@ struct Args {
enum StableDiffusionVersion {
V1_5,
V2_1,
+ Xl,
}
+#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
+ Tokenizer2,
Clip,
+ Clip2,
Unet,
Vae,
}
@@ -115,6 +119,7 @@ enum ModelFile {
impl StableDiffusionVersion {
fn repo(&self) -> &'static str {
match self {
+ Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
}
@@ -122,7 +127,7 @@ impl StableDiffusionVersion {
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -134,7 +139,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -146,7 +151,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@@ -155,12 +160,21 @@ impl StableDiffusionVersion {
}
}
}
+
+ fn clip2_file(&self, use_f16: bool) -> &'static str {
+ match self {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
+ if use_f16 {
+ "text_encoder_2/model.fp16.safetensors"
+ } else {
+ "text_encoder_2/model.safetensors"
+ }
+ }
+ }
+ }
}
impl ModelFile {
- const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
- const TOKENIZER_PATH: &str = "tokenizer.json";
-
fn get(
&self,
filename: Option<String>,
@@ -172,8 +186,24 @@ impl ModelFile {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let (repo, path) = match self {
- Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
+ Self::Tokenizer => {
+ let tokenizer_repo = match version {
+ StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
+ "openai/clip-vit-base-patch32"
+ }
+ StableDiffusionVersion::Xl => {
+                            // This seems similar to the patch32 version except for some very
+                            // small differences in the split regex.
+ "openai/clip-vit-large-patch14"
+ }
+ };
+ (tokenizer_repo, "tokenizer.json")
+ }
+ Self::Tokenizer2 => {
+ ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
+ }
Self::Clip => (version.repo(), version.clip_file(use_f16)),
+ Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
};
@@ -211,6 +241,71 @@ fn output_filename(
}
}
+#[allow(clippy::too_many_arguments)]
+fn text_embeddings(
+ prompt: &str,
+ uncond_prompt: &str,
+ tokenizer: Option<String>,
+ clip_weights: Option<String>,
+ sd_version: StableDiffusionVersion,
+ sd_config: &stable_diffusion::StableDiffusionConfig,
+ use_f16: bool,
+ device: &Device,
+ dtype: DType,
+ first: bool,
+) -> Result<Tensor> {
+ let tokenizer_file = if first {
+ ModelFile::Tokenizer
+ } else {
+ ModelFile::Tokenizer2
+ };
+ let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let pad_id = match &sd_config.clip.pad_with {
+ Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
+ None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+ };
+ println!("Running with prompt \"{prompt}\".");
+ let mut tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while tokens.len() < sd_config.clip.max_position_embeddings {
+ tokens.push(pad_id)
+ }
+ let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ println!("Building the Clip transformer.");
+ let clip_weights_file = if first {
+ ModelFile::Clip
+ } else {
+ ModelFile::Clip2
+ };
+ let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
+ let clip_config = if first {
+ &sd_config.clip
+ } else {
+ sd_config.clip2.as_ref().unwrap()
+ };
+ let text_model =
+ stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
+ let text_embeddings = text_model.forward(&tokens)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+ let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+ Ok(text_embeddings)
+}
+
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@@ -252,46 +347,37 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
+ StableDiffusionVersion::Xl => {
+ stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
+ }
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
- let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
- let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
- let pad_id = match &sd_config.clip.pad_with {
- Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
- None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
- };
- println!("Running with prompt \"{prompt}\".");
- let mut tokens = tokenizer
- .encode(prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while tokens.len() < sd_config.clip.max_position_embeddings {
- tokens.push(pad_id)
- }
- let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
-
- let mut uncond_tokens = tokenizer
- .encode(uncond_prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
- uncond_tokens.push(pad_id)
- }
- let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
-
- println!("Building the Clip transformer.");
- let text_embeddings = {
- let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
- let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
- let text_embeddings = text_model.forward(&tokens)?;
- let uncond_embeddings = text_model.forward(&uncond_tokens)?;
- Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+ let which = match sd_version {
+ StableDiffusionVersion::Xl => vec![true, false],
+ _ => vec![true],
};
+ let text_embeddings = which
+ .iter()
+ .map(|first| {
+ text_embeddings(
+ &prompt,
+ &uncond_prompt,
+ tokenizer.clone(),
+ clip_weights.clone(),
+ sd_version,
+ &sd_config,
+ use_f16,
+ &device,
+ dtype,
+ *first,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+ println!("{text_embeddings:?}");
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
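
For SDXL the run loops over both text encoders (which = vec![true, false]) and concatenates their embeddings along the last dimension: the first CLIP encoder has a hidden size of 768 and the second one 1280, so the concatenated context is 768 + 1280 = 2048 wide, matching the cross_attention_dim of the SDXL unet config below. A shape-level sketch, assuming the candle_core crate (the example imports it under the name candle):

use candle_core::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Batch of 2 (uncond + cond prompt), 77 tokens, per-encoder hidden sizes.
    let emb1 = Tensor::zeros((2, 77, 768), DType::F32, &dev)?; // text_encoder
    let emb2 = Tensor::zeros((2, 77, 1280), DType::F32, &dev)?; // text_encoder_2
    let text_embeddings = Tensor::cat(&[emb1, emb2], D::Minus1)?;
    assert_eq!(text_embeddings.dims(), [2, 77, 2048]);
    Ok(())
}
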
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index e159fa0a..cffc00d8 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -8,6 +8,7 @@ pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
+ pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
@@ -27,10 +28,10 @@ impl StableDiffusionConfig {
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
- bc(320, true, 8),
- bc(640, true, 8),
- bc(1280, true, 8),
- bc(1280, false, 8),
+ bc(320, Some(1), 8),
+ bc(640, Some(1), 8),
+ bc(1280, Some(1), 8),
+ bc(1280, None, 8),
],
center_input_sample: false,
cross_attention_dim: 768,
@@ -51,7 +52,7 @@ impl StableDiffusionConfig {
norm_num_groups: 32,
};
let height = if let Some(height) = height {
- assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
@@ -68,6 +69,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
+ clip2: None,
autoencoder,
scheduler: Default::default(),
unet,
@@ -88,10 +90,10 @@ impl StableDiffusionConfig {
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
- bc(320, true, 5),
- bc(640, true, 10),
- bc(1280, true, 20),
- bc(1280, false, 20),
+ bc(320, Some(1), 5),
+ bc(640, Some(1), 10),
+ bc(1280, Some(1), 20),
+ bc(1280, None, 20),
],
center_input_sample: false,
cross_attention_dim: 1024,
@@ -118,7 +120,7 @@ impl StableDiffusionConfig {
};
let height = if let Some(height) = height {
- assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
768
@@ -135,6 +137,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
+ clip2: None,
autoencoder,
scheduler,
unet,
@@ -155,6 +158,87 @@ impl StableDiffusionConfig {
)
}
+ fn sdxl_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, None, 5),
+ bc(640, Some(2), 10),
+ bc(1280, Some(10), 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 2048,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = ddim::DDIMSchedulerConfig {
+ prediction_type,
+ ..Default::default()
+ };
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 1024
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 1024
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::sdxl(),
+ clip2: Some(clip::Config::sdxl2()),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ pub fn sdxl(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ Self::sdxl_(
+ sliced_attention_size,
+ height,
+ width,
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
+ PredictionType::Epsilon,
+ )
+ }
+
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@@ -193,17 +277,17 @@ impl StableDiffusionConfig {
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
+}
- pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
- &self,
- clip_weights: P,
- device: &Device,
- dtype: DType,
- ) -> Result<clip::ClipTextTransformer> {
- let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
- let weights = weights.deserialize()?;
- let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
- let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
- Ok(text_model)
- }
+pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
+ clip: &clip::Config,
+ clip_weights: P,
+ device: &Device,
+ dtype: DType,
+) -> Result<clip::ClipTextTransformer> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
+ let weights = weights.deserialize()?;
+ let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let text_model = clip::ClipTextTransformer::new(vs, clip)?;
+ Ok(text_model)
}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index eb2dbf10..81bd9547 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -12,7 +12,9 @@ use candle_nn::Module;
#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
pub out_channels: usize,
- pub use_cross_attn: bool,
+    /// When `None`, no cross-attn is used; when `Some(d)`, cross-attn is used and `d` is the
+    /// number of transformer blocks to be used.
+ pub use_cross_attn: Option<usize>,
pub attention_head_dim: usize,
}
@@ -41,22 +43,22 @@ impl Default for UNet2DConditionModelConfig {
blocks: vec![
BlockConfig {
out_channels: 320,
- use_cross_attn: true,
+ use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 640,
- use_cross_attn: true,
+ use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
- use_cross_attn: true,
+ use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
- use_cross_attn: false,
+ use_cross_attn: None,
attention_head_dim: 8,
},
],
@@ -149,13 +151,14 @@ impl UNet2DConditionModel {
downsample_padding: config.downsample_padding,
..Default::default()
};
- if use_cross_attn {
+ if let Some(transformer_layers_per_block) = use_cross_attn {
let config = CrossAttnDownBlock2DConfig {
downblock: db_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block,
};
let block = CrossAttnDownBlock2D::new(
vs_db.pp(&i.to_string()),
@@ -179,6 +182,11 @@ impl UNet2DConditionModel {
})
.collect::<Result<Vec<_>>>()?;
+ // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
+ let mid_transformer_layers_per_block = match config.blocks.last() {
+ None => 1,
+ Some(block) => block.use_cross_attn.unwrap_or(1),
+ };
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
resnet_eps: config.norm_eps,
output_scale_factor: config.mid_block_scale_factor,
@@ -186,8 +194,10 @@ impl UNet2DConditionModel {
attn_num_head_channels: bl_attention_head_dim,
resnet_groups: Some(config.norm_num_groups),
use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block: mid_transformer_layers_per_block,
..Default::default()
};
+
let mid_block = UNetMidBlock2DCrossAttn::new(
vs.pp("mid_block"),
bl_channels,
@@ -231,13 +241,14 @@ impl UNet2DConditionModel {
add_upsample: i < n_blocks - 1,
..Default::default()
};
- if use_cross_attn {
+ if let Some(transformer_layers_per_block) = use_cross_attn {
let config = CrossAttnUpBlock2DConfig {
upblock: ub_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block,
};
let block = CrossAttnUpBlock2D::new(
vs_ub.pp(&i.to_string()),
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 65341e74..1db65222 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -366,6 +366,7 @@ pub struct UNetMidBlock2DCrossAttnConfig {
pub cross_attn_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for UNetMidBlock2DCrossAttnConfig {
@@ -379,6 +380,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
cross_attn_dim: 1280,
sliced_attention_size: None, // Sliced attention disabled
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -414,7 +416,7 @@ impl UNetMidBlock2DCrossAttn {
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let n_heads = config.attn_num_head_channels;
let attn_cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
num_groups: resnet_groups,
context_dim: Some(config.cross_attn_dim),
sliced_attention_size: config.sliced_attention_size,
@@ -565,6 +567,7 @@ pub struct CrossAttnDownBlock2DConfig {
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnDownBlock2DConfig {
@@ -575,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -605,7 +609,7 @@ impl CrossAttnDownBlock2D {
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
context_dim: Some(config.cross_attention_dim),
num_groups: config.downblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
@@ -767,6 +771,7 @@ pub struct CrossAttnUpBlock2DConfig {
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnUpBlock2DConfig {
@@ -777,6 +782,7 @@ impl Default for CrossAttnUpBlock2DConfig {
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -809,7 +815,7 @@ impl CrossAttnUpBlock2D {
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
context_dim: Some(config.cross_attention_dim),
num_groups: config.upblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 19d94385..5ccce317 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -111,6 +111,71 @@ __device__ void conv2d(
dst[dst_i] = static_cast<T>(d);
}
+// Naive implementation of conv_transpose2d.
+template <typename T, typename A>
+__device__ void conv_transpose2d(
+ const size_t src_numel,
+ const size_t w_out,
+ const size_t h_out,
+ const size_t stride,
+ const size_t padding,
+ const size_t out_padding,
+ const size_t *info,
+ const T *src,
+ const T *kernel,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // src: (b_size, c_in, h_in, w_in)
+ // k: (c_in, c_out, h_k, w_k)
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 4;
+ const size_t *k_dims = info + 8;
+ const size_t *k_s = info + 12;
+ const size_t h_k = k_dims[2];
+ const size_t w_k = k_dims[3];
+ const size_t c_out = k_dims[1];
+ const size_t c_in = src_dims[1];
+ const size_t h_in = src_dims[2];
+ const size_t w_in = src_dims[3];
+ if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
+ return;
+ }
+
+ // TODO
+ const size_t b_idx = dst_i / (w_out * h_out * c_out);
+ const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
+ // NCHW layout.
+ const size_t out_y = (dst_i / w_out) % h_out;
+ const size_t out_x = dst_i % w_out;
+
+ const size_t src_idx0 = b_idx * src_s[0];
+ A d = 0;
+ for (int k_x = 0; k_x < (int)w_k; ++k_x) {
+ // let out_x = inp_x * p.stride + k_x - p.padding;
+ int inp_x_stride = (int)(out_x + padding) - k_x;
+ if (inp_x_stride < 0 || inp_x_stride % stride) {
+ continue;
+ }
+ int inp_x = inp_x_stride / stride;
+ if (inp_x >= w_in) continue;
+ for (int k_y = 0; k_y < (int)h_k; ++k_y) {
+ int inp_y_stride = (int)(out_y + padding) - k_y;
+ if (inp_y_stride < 0 || inp_y_stride % stride) {
+ continue;
+ }
+ int inp_y = inp_y_stride / stride;
+ if (inp_y >= h_in) continue;
+ for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+ const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_y * src_s[2] + inp_x * src_s[3];
+ const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
+ d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
+ }
+ }
+ }
+ dst[dst_i] = static_cast<T>(d);
+}
+
template <typename T, typename A>
__device__ void avg_pool2d(
const size_t src_numel,
@@ -293,6 +358,22 @@ extern "C" __global__ void FN_NAME( \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
} \
+#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t src_numel, \
+ const size_t w_out, \
+ const size_t h_out, \
+ const size_t stride, \
+ const size_t padding, \
+ const size_t out_padding, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ const TYPENAME *kernel, \
+ TYPENAME *dst \
+) { \
+ conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
+} \
+
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
@@ -337,6 +418,7 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
+CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
@@ -345,6 +427,7 @@ UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
+CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
@@ -360,6 +443,11 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
+CONVT2D_OP(float, float, conv_transpose2d_f32)
+CONVT2D_OP(double, double, conv_transpose2d_f64)
+CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)
+CONVT2D_OP(uint32_t, uint32_t, conv_transpose2d_u32)
+
AVG_POOL2D_OP(float, float, avg_pool2d_f32)
AVG_POOL2D_OP(double, double, avg_pool2d_f64)
AVG_POOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
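
Unlike the CPU backend, the CUDA kernel above is written in gather form: each thread owns one output element and, for every kernel tap, inverts the forward mapping out = inp * stride + k - padding, accepting the tap only when the inversion is exact and in range (that is what the `% stride` and bound checks implement). A standalone sketch of that inversion, with assumed numbers:

// Returns the input index contributing to output position `out` through kernel tap `k`,
// if any. Mirrors the index checks in the conv_transpose2d CUDA kernel.
fn contributing_input(out: usize, k: usize, stride: usize, padding: usize, in_len: usize) -> Option<usize> {
    let shifted = (out + padding).checked_sub(k)?; // would be negative -> no contribution
    if shifted % stride != 0 {
        return None; // falls between two input samples
    }
    let inp = shifted / stride;
    (inp < in_len).then_some(inp)
}

fn main() {
    // stride 2, padding 0, kernel width 3, input width 3 -> output width 7.
    assert_eq!(contributing_input(4, 0, 2, 0, 3), Some(2));
    assert_eq!(contributing_input(4, 1, 2, 0, 3), None); // (4 - 1) is odd
    assert_eq!(contributing_input(0, 2, 2, 0, 3), None); // 0 - 2 would be negative
}
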
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index c593960b..c372897a 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -52,6 +52,8 @@ pub trait Backend {
dtype: DType,
dev: &Device,
) -> Result<Tensor>;
+
+ fn contains_tensor(&self, name: &str) -> bool;
}
pub trait SimpleBackend {
@@ -64,6 +66,8 @@ pub trait SimpleBackend {
dtype: DType,
dev: &Device,
) -> Result<Tensor>;
+
+ fn contains_tensor(&self, name: &str) -> bool;
}
impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
@@ -78,6 +82,10 @@ impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
) -> Result<Tensor> {
self.as_ref().get(s, name, h, dtype, dev)
}
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.as_ref().contains_tensor(name)
+ }
}
impl<'a, B: Backend> VarBuilderArgs<'a, B> {
@@ -94,6 +102,8 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
}
}
+    /// Return a new `VarBuilder` adding `s` to the current prefix. This can be thought of as a
+    /// `cd` into a directory.
pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
@@ -109,10 +119,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
self.push_prefix(s)
}
+ /// The device used by default.
pub fn device(&self) -> &Device {
&self.data.device
}
+ /// The dtype used by default.
pub fn dtype(&self) -> DType {
self.data.dtype
}
@@ -125,6 +137,14 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
}
}
+    /// This returns true only if a tensor with the passed-in name is available. E.g. when passed
+    /// `a`, true is returned if `prefix.a` exists but false is returned if only `prefix.a.b`
+    /// exists.
+ pub fn contains_tensor(&self, tensor_name: &str) -> bool {
+ let path = self.path(tensor_name);
+ self.data.backend.contains_tensor(&path)
+ }
+
/// Retrieve the tensor associated with the given name at the current path.
pub fn get_with_hints<S: Into<Shape>>(
&self,
@@ -149,6 +169,10 @@ impl SimpleBackend for Zeros {
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
Tensor::zeros(s, dtype, dev)
}
+
+ fn contains_tensor(&self, _name: &str) -> bool {
+ true
+ }
}
impl SimpleBackend for HashMap<String, Tensor> {
@@ -179,6 +203,10 @@ impl SimpleBackend for HashMap<String, Tensor> {
}
tensor.to_device(dev)?.to_dtype(dtype)
}
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.contains_key(name)
+ }
}
impl SimpleBackend for VarMap {
@@ -192,6 +220,10 @@ impl SimpleBackend for VarMap {
) -> Result<Tensor> {
VarMap::get(self, s, name, h, dtype, dev)
}
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.data().lock().unwrap().contains_key(name)
+ }
}
struct SafeTensorWithRouting<'a> {
@@ -228,6 +260,10 @@ impl<'a> SimpleBackend for SafeTensorWithRouting<'a> {
}
Ok(tensor)
}
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.routing.contains_key(name)
+ }
}
impl SimpleBackend for candle::npy::NpzTensors {
@@ -257,6 +293,10 @@ impl SimpleBackend for candle::npy::NpzTensors {
}
Ok(tensor)
}
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.get(name).map_or(false, |v| v.is_some())
+ }
}
impl<'a> VarBuilder<'a> {
@@ -425,4 +465,8 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
}
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.0.routing.contains_key(name)
+ }
}
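
The new contains_tensor support exists so model code can probe which weight-naming scheme a checkpoint uses instead of guessing from the dtype, which is how the attention.rs change earlier in this patch uses it. A small sketch of that pattern (the tensor names are the diffusers ones referenced above; `vs` is any candle_nn::VarBuilder):

// Pick the attention projection paths based on which tensors the checkpoint actually contains.
fn attn_paths(vs: &candle_nn::VarBuilder) -> (&'static str, &'static str, &'static str, &'static str) {
    if vs.contains_tensor("to_q.weight") {
        // Newer diffusers naming.
        ("to_q", "to_k", "to_v", "to_out.0")
    } else {
        // Older naming.
        ("query", "key", "value", "proj_attn")
    }
}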