Diffstat (limited to 'candle-core/src')
 candle-core/src/backprop.rs     | 15
 candle-core/src/conv.rs         | 27
 candle-core/src/cpu_backend.rs  | 12
 candle-core/src/cuda_backend.rs | 15
 candle-core/src/cudnn.rs        |  2
 candle-core/src/device.rs       | 20
 candle-core/src/op.rs           |  3
7 files changed, 74 insertions, 20 deletions
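In short: this commit threads a dilation parameter through conv1d, conv2d, and conv_transpose2d, from the public tensor API down to the CPU/CUDA/cuDNN backends and the backprop graph. A minimal sketch of the updated conv2d call; shapes, values, and the zero-filled tensors are illustrative, not taken from the diff:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Input (b, c_in, h, w) and kernel (c_out, c_in, k_h, k_w).
        let inp = Tensor::zeros((1, 3, 16, 16), DType::F32, &dev)?;
        let k = Tensor::zeros((4, 3, 3, 3), DType::F32, &dev)?;
        // New signature: conv2d(kernel, padding, stride, dilation, groups).
        let out = inp.conv2d(&k, 1, 1, 2, 1)?;
        // out_h = (16 + 2*1 - 2*(3 - 1) - 1) / 1 + 1 = 14
        assert_eq!(out.dims(), &[1, 4, 14, 14]);
        Ok(())
    }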
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 9ecdee4f..f4f90373 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -197,21 +197,28 @@ impl Tensor {
kernel,
padding,
stride,
+ dilation,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
- let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+ let out_size =
+ (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
- let grad_arg =
- grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+ let grad_arg = grad.conv_transpose2d(
+ kernel,
+ *padding,
+ out_padding,
+ *stride,
+ *dilation,
+ )?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
- .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+ .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
*sum_grad = sum_grad.add(&grad_kernel)?;
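Two details of this hunk are easy to misread. First, the gradient w.r.t. the kernel reuses conv2d with the forward stride and dilation swapping roles, so the `*padding, *dilation, *stride` argument order against the `(kernel, padding, stride, dilation, groups)` signature is intentional, not a transposition. Second, out_padding absorbs the rounding that the forward output-size division discards. A worked check of that arithmetic, with illustrative sizes:

    fn main() {
        // Assumed forward-pass sizes, chosen only for illustration.
        let (i_h, k_h, stride, padding, dilation) = (10usize, 3usize, 2usize, 1usize, 2usize);
        // Forward conv2d output height (matches ParamsConv2D::out_h):
        let grad_h = (i_h + 2 * padding - dilation * (k_h - 1) - 1) / stride + 1; // 4
        // conv_transpose2d output height before out_padding:
        let out_size = (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding; // 9
        // out_padding restores the original input height:
        let out_padding = i_h - out_size; // 1
        assert_eq!(out_size + out_padding, i_h);
        println!("grad_h={grad_h} out_size={out_size} out_padding={out_padding}");
    }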
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index d9e0a9ab..1f3ef582 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -11,12 +11,12 @@ pub struct ParamsConv1D {
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self) -> usize {
- let dilation = 1;
- (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+ (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -36,17 +36,16 @@ pub struct ParamsConv2D {
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
- let dilation = 1;
- (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+ (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
- let dilation = 1;
- (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+ (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -66,18 +65,17 @@ pub struct ParamsConvTranspose2D {
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
- let dilation = 1;
- (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+ (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_w(&self) -> usize {
- let dilation = 1;
- (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+ (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
- 2 * self.padding
}
@@ -96,6 +94,7 @@ impl Tensor {
kernel,
padding: params.padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -107,6 +106,7 @@ impl Tensor {
kernel: &Self,
padding: usize,
stride: usize,
+ dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
@@ -130,6 +130,7 @@ impl Tensor {
k_size,
padding,
stride,
+ dilation,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
@@ -154,6 +155,7 @@ impl Tensor {
kernel,
padding: params.padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -165,6 +167,7 @@ impl Tensor {
kernel: &Self,
padding: usize,
stride: usize,
+ dilation: usize,
groups: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
@@ -184,6 +187,7 @@ impl Tensor {
c_in: c_in / groups,
padding,
stride,
+ dilation,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
@@ -206,6 +210,7 @@ impl Tensor {
padding: usize,
output_padding: usize,
stride: usize,
+ dilation: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
@@ -223,6 +228,7 @@ impl Tensor {
padding,
output_padding,
stride,
+ dilation,
};
let storage = self.storage().conv_transpose2d(
self.layout(),
@@ -236,6 +242,7 @@ impl Tensor {
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
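For reference, a sketch of the updated public conv_transpose2d call; per the diff the signature is (kernel, padding, output_padding, stride, dilation), and dilation must stay 1 for now since both backends bail otherwise (see the cpu and cuda hunks below). Shapes are illustrative:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Input (b, c_in, h, w); kernel (c_in, c_out, k_h, k_w).
        let inp = Tensor::zeros((1, 4, 7, 7), DType::F32, &dev)?;
        let k = Tensor::zeros((4, 2, 3, 3), DType::F32, &dev)?;
        // padding = 0, output_padding = 0, stride = 2, dilation = 1.
        let out = inp.conv_transpose2d(&k, 0, 0, 2, 1)?;
        // out_h = (7 - 1) * 2 + 1 * (3 - 1) + 0 + 1 - 0 = 15
        assert_eq!(out.dims(), &[1, 2, 15, 15]);
        Ok(())
    }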
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f52d53b1..60fac0c9 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
- let src_l = p.stride * dst_l + offset;
+ let src_l = p.stride * dst_l + offset * p.dilation;
if src_l < p.padding || src_l >= p.padding + p.l_in {
continue;
}
@@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
for dst_h in 0..out_h {
let dst_idx = dst_idx + dst_h * out_w;
- let src_h = p.stride * dst_h + offset_h;
+ let src_h = p.stride * dst_h + offset_h * p.dilation;
if src_h < p.padding || src_h >= p.i_h + p.padding {
continue;
}
let src_h = src_h - p.padding;
for dst_w in 0..out_w {
let dst_idx = dst_idx + dst_w;
- let src_w = p.stride * dst_w + offset_w;
+ let src_w = p.stride * dst_w + offset_w * p.dilation;
if src_w < p.padding || src_w >= p.i_w + p.padding {
continue;
}
@@ -1186,6 +1186,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
+ if p.dilation != 1 {
+ crate::bail!(
+ "dilation {} is not supported for conv-transpose2d",
+ p.dilation
+ )
+ }
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
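The Conv1D/Conv2D hunks above map each output position and kernel tap to a source index stride * dst + dilation * offset in padded coordinates, then skip taps that land in the padding; only the kernel offset is scaled by the dilation, not the strided term. A standalone scalar sketch of the same indexing (single batch and channel, no bias, names ours):

    fn conv1d_naive(inp: &[f32], k: &[f32], padding: usize, stride: usize, dilation: usize) -> Vec<f32> {
        let l_in = inp.len();
        let l_out = (l_in + 2 * padding - dilation * (k.len() - 1) - 1) / stride + 1;
        let mut out = vec![0f32; l_out];
        for dst_l in 0..l_out {
            for (offset, &kv) in k.iter().enumerate() {
                // Source position in padded coordinates.
                let src_l = stride * dst_l + dilation * offset;
                if src_l < padding || src_l >= padding + l_in {
                    continue; // tap falls in the zero padding
                }
                out[dst_l] += inp[src_l - padding] * kv;
            }
        }
        out
    }

    fn main() {
        let inp = [1f32, 2., 3., 4., 5.];
        let k = [1f32, 1.];
        // dilation = 2 sums elements two apart: [1+3, 2+4, 3+5]
        assert_eq!(conv1d_naive(&inp, &k, 0, 1, 2), vec![4., 6., 8.]);
    }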
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index ed696368..cd06e8d7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -960,7 +960,9 @@ impl<'a> Map2 for Conv1D<'a> {
crate::bail!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
- let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
+ let params = (
+ el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+ );
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@@ -998,7 +1000,9 @@ impl<'a> Map2 for Conv2D<'a> {
crate::bail!("unexpected input shape for conv2d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
- let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
+ let params = (
+ el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+ );
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@@ -1018,6 +1022,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
+ if p.dilation != 1 {
+ crate::bail!(
+ "dilation {} is not supported for conv-transpose2d",
+ p.dilation
+ )
+ }
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
@@ -1043,6 +1053,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
p.stride,
p.padding,
p.output_padding,
+ p.dilation,
&ds,
inp,
k,
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index 3e943e51..235ad6e3 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -48,7 +48,7 @@ pub(crate) fn launch_conv2d<
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
- /* dilation */ [1, 1],
+ /* dilation */ [params.dilation as i32, params.dilation as i32],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
let x_shape = [
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 65232839..84716249 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -81,6 +81,26 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
}
+impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
+ for &[[[[S; N4]; N3]; N2]; N1]
+{
+ fn shape(&self) -> Result<Shape> {
+ Ok(Shape::from((N1, N2, N3, N4)))
+ }
+
+ fn to_cpu_storage(&self) -> CpuStorage {
+ let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
+ for i1 in 0..N1 {
+ for i2 in 0..N2 {
+ for i3 in 0..N3 {
+ vec.extend(self[i1][i2][i3])
+ }
+ }
+ }
+ S::to_cpu_storage_owned(vec)
+ }
+}
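This new impl lets 4-D tensors, e.g. conv kernels in tests, be built directly from nested arrays via Tensor::new. A minimal usage sketch (values illustrative):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // A (1, 1, 2, 2) tensor from a nested array reference.
        let t = Tensor::new(&[[[[1f32, 2.], [3., 4.]]]], &Device::Cpu)?;
        assert_eq!(t.dims(), &[1, 1, 2, 2]);
        Ok(())
    }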
+
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b18f868d..3fe52ebc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -81,6 +81,7 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
+ dilation: usize,
},
#[allow(dead_code)]
@@ -89,6 +90,7 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
+ dilation: usize,
},
#[allow(dead_code)]
@@ -98,6 +100,7 @@ pub enum Op {
padding: usize,
output_padding: usize,
stride: usize,
+ dilation: usize,
},
AvgPool2D {