Diffstat (limited to 'candle-core')
 candle-core/src/backprop.rs           |  62
 candle-core/src/conv.rs               |  12
 candle-core/src/cpu_backend.rs        |  87
 candle-core/src/cuda_backend.rs       |  66
 candle-core/src/display.rs            |   2
 candle-core/src/quantized/avx.rs      | 126
 candle-core/src/quantized/k_quants.rs |   3
 candle-core/tests/conv_tests.rs       | 138

8 files changed, 422 insertions(+), 74 deletions(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 22c28ac4..9ecdee4f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -192,12 +192,68 @@ impl Tensor {
                     *f_sum_grad = f_sum_grad.add(&f_grad)?;
                 }
                 Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
-                Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
+                Op::Conv2D {
+                    arg,
+                    kernel,
+                    padding,
+                    stride,
+                } => {
+                    // The output height for conv_transpose2d is:
+                    // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
+                    let grad_h = grad.dim(2)?;
+                    let k_h = kernel.dim(2)?;
+                    let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+                    let out_padding = arg.dim(2)? - out_size;
+                    let grad_arg =
+                        grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad_arg)?;
+
+                    let grad_kernel = arg
+                        .transpose(0, 1)?
+                        .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+                        .transpose(0, 1)?;
+                    let sum_grad = grads.or_insert(kernel)?;
+                    *sum_grad = sum_grad.add(&grad_kernel)?;
+                }
                 Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                     op: "conv-transpose2d",
                 })?,
-                Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
-                Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
+                Op::AvgPool2D {
+                    arg,
+                    kernel_size,
+                    stride,
+                } => {
+                    if kernel_size != stride {
+                        crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
+                    }
+                    let (_n, _c, h, w) = arg.dims4()?;
+                    let grad_arg = grad.upsample_nearest2d(h, w)?;
+                    let grad_arg =
+                        (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad_arg)?;
+                }
+                Op::MaxPool2D {
+                    arg,
+                    kernel_size,
+                    stride,
+                } => {
+                    if kernel_size != stride {
+                        crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
+                    }
+                    let (_n, _c, h, w) = arg.dims4()?;
+                    // For computing the max-pool gradient, we compute a mask where a 1 means
+                    // that the element is the maximum, then we apply this mask to the
+                    // upsampled gradient (taking into account that multiple max may exist so
+                    // we scale the gradient for this case).
+                    let node_upsampled = node.upsample_nearest2d(h, w)?;
+                    let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
+                    let avg = mask.avg_pool2d(*kernel_size, *stride)?;
+                    let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad_arg)?;
+                }
                 Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest2d",
                 })?,
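A quick sanity check on the `out_padding` arithmetic in the `Conv2D` backward arm above: choosing `out_padding = i_h - out_size` makes the transposed convolution map the output gradient back to the input height exactly, even when the forward stride discarded trailing rows. A minimal sketch (helper names are hypothetical, dilation fixed to 1 as in the code):

```rust
// Forward conv2d output height, dilation = 1.
fn conv_out_h(i_h: usize, k_h: usize, padding: usize, stride: usize) -> usize {
    (i_h + 2 * padding - k_h) / stride + 1
}

// conv_transpose2d output height, dilation = 1, matching the comment above.
fn conv_transpose_out_h(g_h: usize, k_h: usize, padding: usize, stride: usize, out_padding: usize) -> usize {
    (g_h - 1) * stride + (k_h - 1) + 1 + out_padding - 2 * padding
}

fn main() {
    // i_h = 6 with stride 2 drops a row in the forward pass; out_padding recovers it.
    let (i_h, k_h, padding, stride) = (6, 3, 0, 2);
    let g_h = conv_out_h(i_h, k_h, padding, stride); // 2
    let out_size = (g_h - 1) * stride + (k_h - 1) + 1 - 2 * padding; // 5
    let out_padding = i_h - out_size; // 1
    assert_eq!(conv_transpose_out_h(g_h, k_h, padding, stride, out_padding), i_h);
}
```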
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 3455247b..d9e0a9ab 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -71,18 +71,14 @@ pub struct ParamsConvTranspose2D {
 impl ParamsConvTranspose2D {
     pub(crate) fn out_h(&self) -> usize {
         let dilation = 1;
-        (self.i_h - 1) * self.stride - 2 * self.padding
-            + dilation * (self.k_h - 1)
-            + self.output_padding
-            + 1
+        (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+            - 2 * self.padding
     }
 
     pub(crate) fn out_w(&self) -> usize {
         let dilation = 1;
-        (self.i_w - 1) * self.stride - 2 * self.padding
-            + dilation * (self.k_w - 1)
-            + self.output_padding
-            + 1
+        (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+            - 2 * self.padding
     }
 
     pub(crate) fn out_dims(&self) -> Vec<usize> {
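The reordering above is not cosmetic: these fields are all `usize`, so evaluating `- 2 * self.padding` before the additions can underflow in an intermediate value even when the final result is non-negative. A minimal sketch of the failure mode the new ordering avoids (illustrative values, not taken from the diff):

```rust
fn main() {
    let (i_h, k_h, stride, padding, output_padding) = (1usize, 3usize, 1usize, 1usize, 0usize);
    // Old order: (i_h - 1) * stride - 2 * padding evaluates 0 - 2 first, which
    // panics on unsigned underflow in debug builds even though the final
    // output height is 1.
    // New order: keep all additions before the single subtraction.
    let out_h = (i_h - 1) * stride + (k_h - 1) + output_padding + 1 - 2 * padding;
    assert_eq!(out_h, 1);
}
```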
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0b19904b..f52d53b1 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1193,41 +1193,78 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         let (out_h, out_w) = (p.out_h(), p.out_w());
 
         // Output shape: [b_size, c_out, out_h, out_w].
-        let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
         let dst_s0 = p.c_out * out_h * out_w;
         let dst_s1 = out_h * out_w;
         let dst_s2 = out_w;
         let dst_s3 = 1;
+
+        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
+        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
+        let cont_s0 = p.i_h * p.i_w * p.c_in;
+        let cont_s1 = p.i_w * p.c_in;
+        let cont_s2 = p.c_in;
         for b_idx in 0..p.b_size {
-            for out_y in 0..out_h as i32 {
-                for out_x in 0..out_w as i32 {
-                    let inp_x = out_x * p.stride as i32 - p.padding as i32;
-                    let inp_y = out_y * p.stride as i32 - p.padding as i32;
-                    for k_y in 0..p.k_h as i32 {
-                        for k_x in 0..p.k_h as i32 {
-                            let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
-                            let inp_y = inp_y + k_y;
-                            let inp_x = inp_x + k_x;
-                            if inp_x < 0 || inp_y < 0 {
-                                continue;
-                            }
-                            let inp_x = inp_x as usize;
-                            let inp_y = inp_y as usize;
-                            if inp_x < p.i_w && inp_y < p.i_h {
-                                let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
-                                let dst_index = b_idx * dst_s0 + inp_y * dst_s2 + inp_x * dst_s3;
-                                for c_out in 0..k_s0 {
-                                    for c_in in 0..k_s1 {
-                                        let k_index = k_index + c_out * k_s1 + c_in * k_s0;
-                                        let dst_index = dst_index + c_out * dst_s1;
-                                        let inp_index = inp_index + c_in * inp_s1;
-                                        dst[dst_index] += k[k_index] * inp[inp_index]
+            for h_idx in 0..p.i_h {
+                for w_idx in 0..p.i_w {
+                    for c_idx in 0..p.c_in {
+                        let src_idx =
+                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
+                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
+                        inp_cont[dst_idx] = inp[src_idx]
+                    }
+                }
+            }
+        }
+        let num_threads = crate::utils::get_num_threads();
+
+        for k_y in 0..p.k_h {
+            for k_x in 0..p.k_w {
+                crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+                    let k_cont = (0..p.c_in)
+                        .map(|c_in_idx| {
+                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
+                        })
+                        .collect::<Vec<_>>();
+                    for b_idx in 0..p.b_size {
+                        for inp_y in 0..p.i_h {
+                            for inp_x in 0..p.i_w {
+                                let out_x = inp_x * p.stride + k_x;
+                                let out_y = inp_y * p.stride + k_y;
+                                if out_x < p.padding || out_y < p.padding {
+                                    continue;
+                                }
+                                let out_x = out_x - p.padding;
+                                let out_y = out_y - p.padding;
+                                if out_x < out_w && out_y < out_h {
+                                    let inp_cont = &inp_cont
+                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
+                                    let dst_idx = b_idx * dst_s0
+                                        + out_y * dst_s2
+                                        + out_x * dst_s3
+                                        + dst_c_idx * dst_s1;
+                                    let mut d = T::zero();
+                                    unsafe {
+                                        T::vec_dot(
+                                            inp_cont.as_ptr(),
+                                            k_cont.as_ptr(),
+                                            &mut d,
+                                            p.c_in,
+                                        )
+                                    }
+                                    let dst_p = dst.as_ptr();
+                                    // Safety: dst_idx is unique per dst_c_idx, which is used to
+                                    // parallelise the tasks, so no two threads can write to the
+                                    // same location.
+                                    unsafe {
+                                        let ptr = dst_p.add(dst_idx) as *mut T;
+                                        *ptr += d
                                     }
                                 }
                             }
                         }
                     }
-                }
+                })
             }
         }
         Ok(dst)
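The rewrite above first repacks the input from (b, c, h, w) strides into a channels-last contiguous buffer, so the per-pixel reduction over `c_in` becomes a single contiguous `vec_dot`, and then parallelises over output channels so writes never collide. A standalone sketch of just that repacking step (hypothetical helper, `f32` for simplicity):

```rust
// Moves a (b, c, h, w) strided tensor into (b, h, w, c) order so that the
// per-pixel reduction over input channels reads contiguous memory, matching
// the cont_s0/cont_s1/cont_s2 strides in the diff.
fn repack_nchw_to_nhwc(inp: &[f32], b: usize, c: usize, h: usize, w: usize) -> Vec<f32> {
    let mut out = vec![0f32; b * h * w * c];
    for b_i in 0..b {
        for c_i in 0..c {
            for h_i in 0..h {
                for w_i in 0..w {
                    let src = ((b_i * c + c_i) * h + h_i) * w + w_i;
                    let dst = ((b_i * h + h_i) * w + w_i) * c + c_i;
                    out[dst] = inp[src];
                }
            }
        }
    }
    out
}
```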
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 75eaf70a..ed696368 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -977,8 +977,8 @@ impl<'a> Map2 for Conv2D<'a> {
         k_l: &Layout,
         dev: &CudaDevice,
     ) -> Result<CudaSlice<T>> {
-        // Kernel shape: (c_out, c_in_k, w_k, h_k)
-        // Input shape: (b_size, c_in, w_in, c_in)
+        // Kernel shape: (c_out, c_in_k, h_k, w_k)
+        // Input shape: (b_size, c_in, h_in, w_in)
         let p = &self.0;
         let (out_w, out_h) = (p.out_w(), p.out_h());
         let dst_el = p.c_out * out_w * out_h * p.b_size;
@@ -1005,6 +1005,55 @@ impl<'a> Map2 for Conv2D<'a> {
     }
 }
 
+struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
+impl<'a> Map2 for ConvTranspose2D<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        inp: &CudaSlice<T>,
+        inp_l: &Layout,
+        k: &CudaSlice<T>,
+        k_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>> {
+        // Kernel shape: (c_in_k, c_out, h_k, w_k)
+        // Input shape: (b_size, c_in, h_in, w_in)
+        let p = &self.0;
+        let (out_w, out_h) = (p.out_w(), p.out_h());
+        let dst_el = p.c_out * out_w * out_h * p.b_size;
+        let inp = &inp.slice(inp_l.start_offset()..);
+        let k = &k.slice(k_l.start_offset()..);
+        let shape = inp_l.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
+        let ds = if dims.len() == 4 {
+            [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
+        } else {
+            crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
+        };
+        let ds = dev.htod_copy(ds).w()?;
+        let params = (
+            el,
+            out_w,
+            out_h,
+            p.stride,
+            p.padding,
+            p.output_padding,
+            &ds,
+            inp,
+            k,
+            &out,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(out)
+    }
+}
+
 enum PoolOp {
     Max,
     Avg,
@@ -1649,12 +1698,15 @@ impl BackendStorage for CudaStorage {
 
     fn conv_transpose2d(
         &self,
-        _l: &Layout,
-        _kernel: &Self,
-        _kernel_l: &Layout,
-        _params: &crate::conv::ParamsConvTranspose2D,
+        l: &Layout,
+        kernel: &Self,
+        kernel_l: &Layout,
+        params: &crate::conv::ParamsConvTranspose2D,
     ) -> Result<Self> {
-        todo!()
+        let device = self.device().clone();
+        let slice =
+            ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+        Ok(Self { slice, device })
     }
 
     fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index 8390a4a0..b497699b 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -43,7 +43,7 @@ impl Tensor {
                 }
             }
         }
-        write!(f, "; {} ,{}]", self.dtype().as_str(), device_str)
+        write!(f, "; {}{}]", self.dtype().as_str(), device_str)
     }
 }
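With the CUDA kernel wired up, `conv_transpose2d` now runs on GPU tensors instead of hitting the old `todo!()`. A hedged usage sketch (function name hypothetical; assumes a CUDA-enabled build, with shapes following the (b, c, h, w) and (c_in, c_out, k_h, k_w) conventions noted in the comments above):

```rust
use candle_core::{Device, Result, Tensor};

fn upsample_with_transpose_conv() -> Result<Tensor> {
    let dev = Device::new_cuda(0)?;
    let t = Tensor::randn(0f32, 1f32, (1, 4, 5, 5), &dev)?;
    // Transposed-conv kernels are (c_in, c_out, k_h, k_w).
    let w = Tensor::randn(0f32, 1f32, (4, 2, 3, 3), &dev)?;
    // padding = 0, output_padding = 0, stride = 2 -> a (1, 2, 11, 11) output.
    t.conv_transpose2d(&w, 0, 0, 2)
}
```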
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 96087feb..f906d090 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -1,5 +1,6 @@
-use super::k_quants::{BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
 use crate::Result;
+use byteorder::{ByteOrder, LittleEndian};
 use half::f16;
 
 #[cfg(target_arch = "x86")]
@@ -89,17 +90,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
     }
 }
 
-const K_SHUFFLE: [u8; 128] = [
-    0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
-    4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
-    8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11,
-    11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14,
-    14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
-];
-
+#[inline(always)]
 unsafe fn get_scale_shuffle(i: usize) -> __m128i {
+    const K_SHUFFLE: [u8; 128] = [
+        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
+        3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
+        7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
+        11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
+        13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
+    ];
     _mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
 }
+
+#[inline(always)]
+unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
+    const K_SHUFFLE: [u8; 256] = [
+        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+        0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+        2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+        4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+        6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
+        11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
+        12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
+        13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+        14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+    ];
+    _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
     let qk = QK_K;
@@ -187,3 +206,92 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
     Ok(hsum_float_8(acc))
     }
 }
+
+#[inline(always)]
+unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
+    _mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
+}
+
+#[inline(always)]
+pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+    }
+    let mut utmp = [0u32; 4];
+    let kmask1: u32 = 0x3f3f3f3f;
+    let kmask2: u32 = 0x0f0f0f0f;
+    let kmask3: u32 = 0x03030303;
+
+    unsafe {
+        let m4 = _mm256_set1_epi8(0xF);
+
+        let mut acc = _mm256_setzero_ps();
+        let mut acc_m = _mm_setzero_ps();
+
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let d = y.d * x.d.to_f32();
+            let dmin = -y.d * x.dmin.to_f32();
+
+            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+            utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+            let uaux = utmp[1] & kmask1;
+            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+            utmp[2] = uaux;
+            utmp[0] &= kmask1;
+
+            let mut q4 = x.qs.as_ptr();
+            let mut q8 = y.qs.as_ptr();
+
+            let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
+                utmp[3] as i32,
+                utmp[2] as i32,
+                utmp[1] as i32,
+                utmp[0] as i32,
+            ));
+
+            let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
+            let q8s = _mm_hadd_epi16(
+                _mm256_extracti128_si256(q8sums, 0),
+                _mm256_extracti128_si256(q8sums, 1),
+            );
+            let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+            acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
+
+            let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+            let scales = mm256_set_m128i(sc128, sc128);
+
+            let mut sumi = _mm256_setzero_si256();
+
+            for j in 0..QK_K / 64 {
+                let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
+                let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
+
+                let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
+                q4 = q4.add(32);
+                let q4l = _mm256_and_si256(q4bits, m4);
+                let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+                let q8l = _mm256_loadu_si256(q8 as *const __m256i);
+                q8 = q8.add(32);
+                let p16l = _mm256_maddubs_epi16(q4l, q8l);
+                let p16l = _mm256_madd_epi16(scale_l, p16l);
+                sumi = _mm256_add_epi32(sumi, p16l);
+
+                let q8h = _mm256_loadu_si256(q8 as *const __m256i);
+                q8 = q8.add(32);
+                let p16h = _mm256_maddubs_epi16(q4h, q8h);
+                let p16h = _mm256_madd_epi16(scale_h, p16h);
+                sumi = _mm256_add_epi32(sumi, p16h);
+            }
+
+            let vd = _mm256_set1_ps(d);
+            acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+        }
+
+        let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+        let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+        Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
+    }
+}
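The trickiest part of `vec_dot_q4k_q8k` above is the scale unpacking: a Q4_K block packs its 6-bit scales and mins into the 12-byte `scales` field, and the `kmask` shuffling rearranges them into four `u32` words that the SIMD code can widen. A scalar sketch of just that step (helper name hypothetical, logic copied from the diff; assumes the `byteorder` dependency the diff itself adds):

```rust
use byteorder::{ByteOrder, LittleEndian};

// Reads the 12 packed bytes as three little-endian u32 words, then rearranges
// the 6-bit scale/min fields into utmp[0..4] exactly as the AVX path does.
fn unpack_scales(scales: &[u8; 12]) -> [u32; 4] {
    let mut utmp = [0u32; 4];
    LittleEndian::read_u32_into(&scales[..], &mut utmp[0..3]);
    let (kmask1, kmask2, kmask3) = (0x3f3f3f3fu32, 0x0f0f0f0fu32, 0x03030303u32);
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
    let uaux = utmp[1] & kmask1;
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
    utmp[2] = uaux;
    utmp[0] &= kmask1;
    utmp
}

fn main() {
    let packed = [0u8; 12]; // illustrative: all scales and mins zero
    assert_eq!(unpack_scales(&packed), [0, 0, 0, 0]);
}
```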
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 7b405ec9..7f14600b 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1104,6 +1104,9 @@ impl GgmlType for BlockQ4K {
 
     #[allow(unreachable_code)]
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "avx")]
+        return super::avx::vec_dot_q4k_q8k(n, xs, ys);
+
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q4k_q8k(n, xs, ys);
 
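The `vec_dot` dispatch above relies on `#[cfg(target_feature = ...)]` early returns: whichever SIMD path is compiled in wins, and the portable fallback below the returns becomes dead code, which is why the function carries `#[allow(unreachable_code)]`. A minimal sketch of the same pattern with a hypothetical kernel:

```rust
#[cfg(target_feature = "avx")]
fn simd_dot(xs: &[f32], ys: &[f32]) -> f32 {
    // Stand-in for a real AVX kernel; only compiled on AVX builds.
    xs.iter().zip(ys).map(|(x, y)| x * y).sum()
}

#[allow(unreachable_code)]
fn vec_dot(xs: &[f32], ys: &[f32]) -> f32 {
    #[cfg(target_feature = "avx")]
    return simd_dot(xs, ys);
    // Portable fallback, unreachable when the cfg above is active.
    xs.iter().zip(ys).map(|(x, y)| x * y).sum()
}
```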
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 310d2462..1c378e5e 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use candle_core::{test_device, test_utils, Device, Tensor};
+use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};
 
 /* This test is based on the following script.
 import torch
@@ -76,6 +76,11 @@
 print(t.flatten())
 print(w.flatten())
 res = torch.nn.functional.conv2d(t, w)
 print(res.flatten())
+
+w_t = w.transpose(0, 1)
+res = torch.nn.functional.conv_transpose2d(t, w_t)
+print(res.shape)
+print(res)
 */
 fn conv2d(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -117,6 +122,31 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+    assert_eq!(res.dims(), [1, 2, 7, 7]);
+    assert_eq!(
+        test_utils::to_vec3_round(&res.i(0)?, 4)?,
+        [
+            [
+                [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
+                [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
+                [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
+                [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
+                [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
+                [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
+                [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
+            ],
+            [
+                [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
+                [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
+                [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
+                [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
+                [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
+                [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
+                [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
+            ]
+        ]
+    );
     Ok(())
 }
 
@@ -170,26 +200,23 @@ fn conv2d_small(dev: &Device) -> Result<()> {
             0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
         ]
     );
-    // TODO: enable the test for cuda once we have the proper implementation in place.
-    if dev.is_cpu() {
-        let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
-        assert_eq!(res.dims(), [1, 1, 3, 3]);
-        assert_eq!(
-            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
-            [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
-        );
-        let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
-        assert_eq!(res.dims(), [2, 2, 3, 3]);
-        assert_eq!(
-            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
-            [
-                -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728,
-                0.528, -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838,
-                0.5802, -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396,
-                -0.8156, 0.4594, 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
-            ]
-        );
-    }
+    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+    assert_eq!(res.dims(), [1, 1, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+        [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
+    );
+    let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
+    assert_eq!(res.dims(), [2, 2, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+        [
+            -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528,
+            -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802,
+            -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594,
+            2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267
+        ]
+    );
     Ok(())
 }
 
@@ -243,6 +270,74 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
     Ok(())
 }
 
+fn conv2d_grad(dev: &Device) -> Result<()> {
+    use candle_core::Var;
+    let t = Var::from_slice(
+        &[
+            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
+            1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
+            1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
+            0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
+            1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
+            0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
+            0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
+            0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
+            -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
+            -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
+        ],
+        (1, 4, 5, 5),
+        dev,
+    )?;
+    let w = Var::from_slice(
+        &[
+            -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
+            -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
+            -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
+            0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
+            0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
+            -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
+            1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
+            0.5583, 0.4623, 0.6026,
+        ],
+        (2, 4, 3, 3),
+        dev,
+    )?;
+    let res = t.conv2d(&w, 0, 1, 1)?;
+    let loss = res.sqr()?.sum_all()?;
+    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
+    let grads = loss.backward()?;
+    let grad_t = grads.get(&t).unwrap();
+    let grad_w = grads.get(&w).unwrap();
+    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+    assert_eq!(
+        test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?,
+        [
+            9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35,
+            -39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08,
+            -20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78,
+            -75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32,
+            49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62,
+            -10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28,
+            20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48,
+            25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28,
+            -28.57, -9.13, 7.21, -9.05, -9.62, -11.25
+        ]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?,
+        [
+            -28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5,
+            28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63,
+            22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08,
+            58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03,
+            47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05,
+            -34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
+        ]
+    );
+    Ok(())
+}
+
 test_device!(conv1d, conv1d_cpu, conv1d_gpu);
 test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
 test_device!(conv2d, conv2d_cpu, conv2d_gpu);
@@ -253,3 +348,4 @@ test_device!(
 );
 test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
 test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
+test_device!(conv2d_grad, conv2d_grad_cpu, conv2d_grad_gpu);
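For reference, the gradient flow exercised by `conv2d_grad` above looks like this from user code: wrap the inputs in `Var`s, run the forward pass, and read the gradients back from the store returned by `backward`. A hedged sketch with random data (function name hypothetical; shapes match the test):

```rust
use candle_core::{Device, Result, Tensor, Var};

fn conv2d_grads() -> Result<(Tensor, Tensor)> {
    let dev = Device::Cpu;
    let t = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1, 4, 5, 5), &dev)?)?;
    let w = Var::from_tensor(&Tensor::randn(0f32, 1f32, (2, 4, 3, 3), &dev)?)?;
    // Forward pass and scalar loss, as in the test above.
    let loss = t.conv2d(&w, 0, 1, 1)?.sqr()?.sum_all()?;
    let grads = loss.backward()?;
    // The GradStore is keyed by the underlying tensors of the Vars.
    let grad_t = grads.get(&t).expect("no grad for t").clone();
    let grad_w = grads.get(&w).expect("no grad for w").clone();
    Ok((grad_t, grad_w))
}

fn main() -> Result<()> {
    let (grad_t, grad_w) = conv2d_grads()?;
    println!("{:?} {:?}", grad_t.dims(), grad_w.dims());
    Ok(())
}
```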