Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/backprop.rs     | 15
-rw-r--r--  candle-core/src/conv.rs         | 27
-rw-r--r--  candle-core/src/cpu_backend.rs  | 12
-rw-r--r--  candle-core/src/cuda_backend.rs | 15
-rw-r--r--  candle-core/src/cudnn.rs        |  2
-rw-r--r--  candle-core/src/device.rs       | 20
-rw-r--r--  candle-core/src/op.rs           |  3
7 files changed, 74 insertions, 20 deletions
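
The commit threads a caller-visible `dilation` argument through `conv1d`, `conv2d`, and `conv_transpose2d`, replacing the hard-coded `let dilation = 1;` in the output-size computations. The formula the `ParamsConv*` methods switch to is the standard dilated-convolution one: the kernel's effective span grows from `k` to `dilation * (k - 1) + 1` and everything else stays the same. As a quick sanity check, here is a minimal standalone sketch of the 1-D case (the free function and its name are illustrative, not part of the patch):

fn l_out(l_in: usize, k: usize, padding: usize, stride: usize, dilation: usize) -> usize {
    // Mirrors the updated `ParamsConv1D::l_out`: only the kernel span changes
    // with dilation, from `k` to `dilation * (k - 1) + 1`.
    (l_in + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}

fn main() {
    assert_eq!(l_out(10, 3, 0, 1, 1), 8); // dilation 1: plain convolution
    assert_eq!(l_out(10, 3, 0, 1, 2), 6); // dilation 2: the kernel spans 5 inputs
}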
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 9ecdee4f..f4f90373 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -197,21 +197,28 @@ impl Tensor {
                     kernel,
                     padding,
                     stride,
+                    dilation,
                 } => {
                     // The output height for conv_transpose2d is:
                     // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
                     let grad_h = grad.dim(2)?;
                     let k_h = kernel.dim(2)?;
-                    let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+                    let out_size =
+                        (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
                     let out_padding = arg.dim(2)? - out_size;
-                    let grad_arg =
-                        grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+                    let grad_arg = grad.conv_transpose2d(
+                        kernel,
+                        *padding,
+                        out_padding,
+                        *stride,
+                        *dilation,
+                    )?;
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
                     let grad_kernel = arg
                         .transpose(0, 1)?
-                        .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+                        .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                         .transpose(0, 1)?;
                     let sum_grad = grads.or_insert(kernel)?;
                     *sum_grad = sum_grad.add(&grad_kernel)?;
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index d9e0a9ab..1f3ef582 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -11,12 +11,12 @@ pub struct ParamsConv1D {
     pub(crate) k_size: usize,
     pub(crate) padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }
 
 impl ParamsConv1D {
     pub(crate) fn l_out(&self) -> usize {
-        let dilation = 1;
-        (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+        (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
     }
 
     pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -36,17 +36,16 @@ pub struct ParamsConv2D {
     pub(crate) c_in: usize,
     pub(crate) padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }
 
 impl ParamsConv2D {
     pub(crate) fn out_h(&self) -> usize {
-        let dilation = 1;
-        (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+        (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
     }
 
     pub(crate) fn out_w(&self) -> usize {
-        let dilation = 1;
-        (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+        (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
     }
 
     pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -66,18 +65,17 @@ pub struct ParamsConvTranspose2D {
     pub(crate) padding: usize,
     pub(crate) output_padding: usize,
     pub(crate) stride: usize,
+    pub(crate) dilation: usize,
 }
 
 impl ParamsConvTranspose2D {
     pub(crate) fn out_h(&self) -> usize {
-        let dilation = 1;
-        (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+        (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
             - 2 * self.padding
     }
 
     pub(crate) fn out_w(&self) -> usize {
-        let dilation = 1;
-        (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+        (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
             - 2 * self.padding
     }
@@ -96,6 +94,7 @@ impl Tensor {
             kernel,
             padding: params.padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -107,6 +106,7 @@ impl Tensor {
         kernel: &Self,
         padding: usize,
         stride: usize,
+        dilation: usize,
         groups: usize,
     ) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.dims3()?;
@@ -130,6 +130,7 @@ impl Tensor {
             k_size,
             padding,
             stride,
+            dilation,
         };
         if groups == 1 {
             self.conv1d_single_group(kernel, &params)
@@ -154,6 +155,7 @@ impl Tensor {
             kernel,
             padding: params.padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -165,6 +167,7 @@ impl Tensor {
         kernel: &Self,
         padding: usize,
         stride: usize,
+        dilation: usize,
         groups: usize,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
@@ -184,6 +187,7 @@ impl Tensor {
             c_in: c_in / groups,
             padding,
             stride,
+            dilation,
         };
         if groups == 1 {
             self.conv2d_single_group(kernel, &params)
@@ -206,6 +210,7 @@ impl Tensor {
         padding: usize,
         output_padding: usize,
         stride: usize,
+        dilation: usize,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
         let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
@@ -223,6 +228,7 @@ impl Tensor {
             padding,
             output_padding,
             stride,
+            dilation,
         };
         let storage = self.storage().conv_transpose2d(
             self.layout(),
@@ -236,6 +242,7 @@ impl Tensor {
             padding: params.padding,
             output_padding: params.output_padding,
             stride: params.stride,
+            dilation: params.dilation,
         });
         let out_dims = params.out_dims();
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f52d53b1..60fac0c9 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
             let dst_idx = dst_idx + b_idx * p.c_out * l_out;
             for dst_l in 0..l_out {
                 let dst_idx = dst_idx + dst_l;
-                let src_l = p.stride * dst_l + offset;
+                let src_l = (p.stride * dst_l + offset) * p.dilation;
                 if src_l < p.padding || src_l >= p.padding + p.l_in {
                     continue;
                 }
@@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
             let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
             for dst_h in 0..out_h {
                 let dst_idx = dst_idx + dst_h * out_w;
-                let src_h = p.stride * dst_h + offset_h;
+                let src_h = (p.stride * dst_h + offset_h) * p.dilation;
                 if src_h < p.padding || src_h >= p.i_h + p.padding {
                     continue;
                 }
                 let src_h = src_h - p.padding;
                 for dst_w in 0..out_w {
                     let dst_idx = dst_idx + dst_w;
-                    let src_w = p.stride * dst_w + offset_w;
+                    let src_w = (p.stride * dst_w + offset_w) * p.dilation;
                     if src_w < p.padding || src_w >= p.i_w + p.padding {
                         continue;
                     }
@@ -1186,6 +1186,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
     const OP: &'static str = "conv_transpose2d";
     fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
         let p = self.0;
+        if p.dilation != 1 {
+            crate::bail!(
+                "dilation {} is not supported for conv-transpose2d",
+                p.dilation
+            )
+        }
         let inp = &inp[inp_l.start_offset()..];
         let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
         let k = &k[k_l.start_offset()..];
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index ed696368..cd06e8d7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -960,7 +960,9 @@ impl<'a> Map2 for Conv1D<'a> {
             crate::bail!("unexpected input shape for conv1d {dims:?}")
         };
         let ds = dev.htod_copy(ds).w()?;
-        let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
+        let params = (
+            el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+        );
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
@@ -998,7 +1000,9 @@ impl<'a> Map2 for Conv2D<'a> {
             crate::bail!("unexpected input shape for conv2d {dims:?}")
         };
         let ds = dev.htod_copy(ds).w()?;
-        let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
+        let params = (
+            el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+        );
         // SAFETY: ffi.
         unsafe { func.launch(cfg, params) }.w()?;
         Ok(out)
@@ -1018,6 +1022,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         // Kernel shape: (c_in_k, c_out, h_k, w_k)
         // Input shape: (b_size, c_in, h_in, w_in)
         let p = &self.0;
+        if p.dilation != 1 {
+            crate::bail!(
+                "dilation {} is not supported for conv-transpose2d",
+                p.dilation
+            )
+        }
         let (out_w, out_h) = (p.out_w(), p.out_h());
         let dst_el = p.c_out * out_w * out_h * p.b_size;
         let inp = &inp.slice(inp_l.start_offset()..);
@@ -1043,6 +1053,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
             p.stride,
             p.padding,
             p.output_padding,
+            p.dilation,
             &ds,
             inp,
             k,
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index 3e943e51..235ad6e3 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -48,7 +48,7 @@ pub(crate) fn launch_conv2d<
     let conv = cudnn.create_conv2d::<T>(
         /* pad */ [params.padding as i32, params.padding as i32],
         /* stride */ [params.stride as i32, params.stride as i32],
-        /* dilation */ [1, 1],
+        /* dilation */ [params.dilation as i32, params.dilation as i32],
         cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
     )?;
     let x_shape = [
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 65232839..84716249 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -81,6 +81,26 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
     }
 }
 
+impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
+    for &[[[[S; N4]; N3]; N2]; N1]
+{
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from((N1, N2, N3, N4)))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
+        for i1 in 0..N1 {
+            for i2 in 0..N2 {
+                for i3 in 0..N3 {
+                    vec.extend(self[i1][i2][i3])
+                }
+            }
+        }
+        S::to_cpu_storage_owned(vec)
+    }
+}
+
 impl Device {
     pub fn new_cuda(ordinal: usize) -> Result<Self> {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b18f868d..3fe52ebc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -81,6 +81,7 @@ pub enum Op {
         kernel: Tensor,
         padding: usize,
         stride: usize,
+        dilation: usize,
     },
 
     #[allow(dead_code)]
@@ -89,6 +90,7 @@ pub enum Op {
         kernel: Tensor,
         padding: usize,
         stride: usize,
+        dilation: usize,
     },
 
     #[allow(dead_code)]
@@ -98,6 +100,7 @@ pub enum Op {
         padding: usize,
         output_padding: usize,
         stride: usize,
+        dilation: usize,
     },
 
     AvgPool2D {
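
In the backprop.rs hunk, the gradient with respect to the conv2d input is computed as a conv_transpose2d of the output gradient. Because the forward output size comes from a floor division, several input sizes can map to the same output size, so the hunk recovers the lost rows as output_padding from the stored input dimension. A sketch of that recovery under the same formulas (the helper is hypothetical; the patch inlines this computation):

fn out_padding_for_grad(
    i_h: usize,    // forward input height, `arg.dim(2)?` in the patch
    grad_h: usize, // forward output height, `grad.dim(2)?` in the patch
    k_h: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    // conv_transpose2d output height with output_padding = 0:
    let out_size = (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
    // Whatever the forward floor division discarded must be added back:
    i_h - out_size
}

fn main() {
    // i_h = 7 and i_h = 8 both give grad_h = 3 for k_h = 3, stride = 2,
    // padding = 0, dilation = 1, so the transpose needs 0 resp. 1 extra rows:
    assert_eq!(out_padding_for_grad(7, 3, 3, 2, 0, 1), 0);
    assert_eq!(out_padding_for_grad(8, 3, 3, 2, 0, 1), 1);
}

The kernel gradient just below that passes *dilation in the stride slot and *stride in the dilation slot of the new conv2d signature, which matches the standard identity that the weight gradient of a stride-s, dilation-d convolution is itself a stride-d, dilation-s correlation of the input with the output gradient.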
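
Since `dilation` lands before `groups`, every public call site changes. A hedged usage sketch of the new conv1d signature (assuming the crate is consumed as `candle_core` and that `Tensor::randn(mean, std, shape, device)` is available, as elsewhere in candle at the time; the shape follows the formula above):

use candle_core::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    // (batch, c_in, l) input and (c_out, c_in, k) kernel.
    let x = Tensor::randn(0f32, 1., (1, 4, 16), &dev)?;
    let w = Tensor::randn(0f32, 1., (2, 4, 3), &dev)?;
    // New argument order: padding, stride, dilation, groups.
    let y = x.conv1d(&w, 0, 1, 2, 1)?;
    // l_out = (16 + 0 - 2 * (3 - 1) - 1) / 1 + 1 = 12
    assert_eq!(y.dims(), &[1, 2, 12]);
    Ok(())
}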
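
The device.rs hunk is a supporting change: an NdArray impl for 4-D fixed-size arrays, convenient for building small (b, c, h, w) tensors in convolution tests. A usage sketch, assuming the existing `Tensor::new` entry point that accepts any `NdArray`:

use candle_core::{Device, Result, Tensor};

fn demo_4d() -> Result<()> {
    // A (1, 1, 2, 2) tensor from a nested fixed-size array, via the new impl:
    let t = Tensor::new(&[[[[1f32, 2.], [3., 4.]]]], &Device::Cpu)?;
    assert_eq!(t.dims(), &[1, 1, 2, 2]);
    Ok(())
}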