summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/basics.rs2
-rw-r--r--candle-core/examples/cpu_benchmarks.rs4
-rw-r--r--candle-core/examples/cuda_basics.rs6
-rw-r--r--candle-core/src/backprop.rs15
-rw-r--r--candle-core/src/conv.rs27
-rw-r--r--candle-core/src/cpu_backend.rs12
-rw-r--r--candle-core/src/cuda_backend.rs15
-rw-r--r--candle-core/src/cudnn.rs2
-rw-r--r--candle-core/src/device.rs20
-rw-r--r--candle-core/src/op.rs3
-rw-r--r--candle-core/tests/conv_tests.rs135
-rw-r--r--candle-examples/examples/musicgen/encodec_model.rs2
-rw-r--r--candle-examples/examples/stable-diffusion/resnet.rs2
-rw-r--r--candle-examples/examples/whisper/model.rs2
-rw-r--r--candle-examples/examples/yolo-v3/darknet.rs1
-rw-r--r--candle-examples/examples/yolo-v8/model.rs1
-rw-r--r--candle-kernels/src/conv.cu18
-rw-r--r--candle-nn/src/conv.rs6
-rw-r--r--candle-wasm-examples/whisper/src/model.rs2
-rw-r--r--candle-wasm-examples/yolo/src/model.rs1
20 files changed, 231 insertions, 45 deletions
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index 9d4734de..ad008177 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
- let res = inp.conv2d(&w, 0, 1, 1)?;
+ let res = inp.conv2d(&w, 0, 1, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
index 1ebd9b75..13175ac1 100644
--- a/candle-core/examples/cpu_benchmarks.rs
+++ b/candle-core/examples/cpu_benchmarks.rs
@@ -40,7 +40,7 @@ impl Benchmark for Conv1d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv1d(&d.1, 0, 1, 1)
+ d.0.conv1d(&d.1, 0, 1, 1, 1)
}
const ITERS: usize = 5;
@@ -59,7 +59,7 @@ impl Benchmark for Conv2d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv2d(&d.1, 0, 1, 1)
+ d.0.conv2d(&d.1, 0, 1, 1, 1)
}
const ITERS: usize = 1;
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index cbdafd64..ad207461 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -11,11 +11,11 @@ fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
- let out_t = in_t.conv2d(&k_t, 0, 1, 1)?;
+ let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
println!("{out_t}");
let in_t = in_t.to_device(&Device::Cpu)?;
let k_t = k_t.to_device(&Device::Cpu)?;
- let out_t2 = in_t.conv2d(&k_t, 0, 1, 1)?;
+ let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
.sqr()?
.sum_all()?;
@@ -23,7 +23,7 @@ fn main() -> Result<()> {
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
- let res = t.conv2d(&w, 1, 1, 1)?;
+ let res = t.conv2d(&w, 1, 1, 1, 1)?;
println!("{res:?}");
Ok(())
}
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 9ecdee4f..f4f90373 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -197,21 +197,28 @@ impl Tensor {
kernel,
padding,
stride,
+ dilation,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
- let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+ let out_size =
+ (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
- let grad_arg =
- grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+ let grad_arg = grad.conv_transpose2d(
+ kernel,
+ *padding,
+ out_padding,
+ *stride,
+ *dilation,
+ )?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
- .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+ .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
*sum_grad = sum_grad.add(&grad_kernel)?;
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index d9e0a9ab..1f3ef582 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -11,12 +11,12 @@ pub struct ParamsConv1D {
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self) -> usize {
- let dilation = 1;
- (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+ (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -36,17 +36,16 @@ pub struct ParamsConv2D {
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
- let dilation = 1;
- (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+ (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
- let dilation = 1;
- (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+ (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -66,18 +65,17 @@ pub struct ParamsConvTranspose2D {
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
- let dilation = 1;
- (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+ (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_w(&self) -> usize {
- let dilation = 1;
- (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+ (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
- 2 * self.padding
}
@@ -96,6 +94,7 @@ impl Tensor {
kernel,
padding: params.padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -107,6 +106,7 @@ impl Tensor {
kernel: &Self,
padding: usize,
stride: usize,
+ dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
@@ -130,6 +130,7 @@ impl Tensor {
k_size,
padding,
stride,
+ dilation,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
@@ -154,6 +155,7 @@ impl Tensor {
kernel,
padding: params.padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -165,6 +167,7 @@ impl Tensor {
kernel: &Self,
padding: usize,
stride: usize,
+ dilation: usize,
groups: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
@@ -184,6 +187,7 @@ impl Tensor {
c_in: c_in / groups,
padding,
stride,
+ dilation,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
@@ -206,6 +210,7 @@ impl Tensor {
padding: usize,
output_padding: usize,
stride: usize,
+ dilation: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
@@ -223,6 +228,7 @@ impl Tensor {
padding,
output_padding,
stride,
+ dilation,
};
let storage = self.storage().conv_transpose2d(
self.layout(),
@@ -236,6 +242,7 @@ impl Tensor {
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f52d53b1..60fac0c9 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
- let src_l = p.stride * dst_l + offset;
+ let src_l = (p.stride * dst_l + offset) * p.dilation;
if src_l < p.padding || src_l >= p.padding + p.l_in {
continue;
}
@@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
for dst_h in 0..out_h {
let dst_idx = dst_idx + dst_h * out_w;
- let src_h = p.stride * dst_h + offset_h;
+ let src_h = (p.stride * dst_h + offset_h) * p.dilation;
if src_h < p.padding || src_h >= p.i_h + p.padding {
continue;
}
let src_h = src_h - p.padding;
for dst_w in 0..out_w {
let dst_idx = dst_idx + dst_w;
- let src_w = p.stride * dst_w + offset_w;
+ let src_w = (p.stride * dst_w + offset_w) * p.dilation;
if src_w < p.padding || src_w >= p.i_w + p.padding {
continue;
}
@@ -1186,6 +1186,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
+ if p.dilation != 1 {
+ crate::bail!(
+ "dilation {} is not supported for conv-transpose2d",
+ p.dilation
+ )
+ }
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index ed696368..cd06e8d7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -960,7 +960,9 @@ impl<'a> Map2 for Conv1D<'a> {
crate::bail!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
- let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
+ let params = (
+ el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+ );
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@@ -998,7 +1000,9 @@ impl<'a> Map2 for Conv2D<'a> {
crate::bail!("unexpected input shape for conv2d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
- let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
+ let params = (
+ el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
+ );
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
@@ -1018,6 +1022,12 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
+ if p.dilation != 1 {
+ crate::bail!(
+ "dilation {} is not supported for conv-transpose2d",
+ p.dilation
+ )
+ }
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
@@ -1043,6 +1053,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
p.stride,
p.padding,
p.output_padding,
+ p.dilation,
&ds,
inp,
k,
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index 3e943e51..235ad6e3 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -48,7 +48,7 @@ pub(crate) fn launch_conv2d<
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
- /* dilation */ [1, 1],
+ /* dilation */ [params.dilation as i32, params.dilation as i32],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
let x_shape = [
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 65232839..84716249 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -81,6 +81,26 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
}
+impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
+ for &[[[[S; N4]; N3]; N2]; N1]
+{
+ fn shape(&self) -> Result<Shape> {
+ Ok(Shape::from((N1, N2, N3, N4)))
+ }
+
+ fn to_cpu_storage(&self) -> CpuStorage {
+ let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
+ for i1 in 0..N1 {
+ for i2 in 0..N2 {
+ for i3 in 0..N3 {
+ vec.extend(self[i1][i2][i3])
+ }
+ }
+ }
+ S::to_cpu_storage_owned(vec)
+ }
+}
+
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b18f868d..3fe52ebc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -81,6 +81,7 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
+ dilation: usize,
},
#[allow(dead_code)]
@@ -89,6 +90,7 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
+ dilation: usize,
},
#[allow(dead_code)]
@@ -98,6 +100,7 @@ pub enum Op {
padding: usize,
output_padding: usize,
stride: usize,
+ dilation: usize,
},
AvgPool2D {
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 1c378e5e..05015995 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -32,13 +32,13 @@ fn conv1d(dev: &Device) -> Result<()> {
dev,
)?
.reshape((2, 4, 3))?;
- let res = t.conv1d(&w, 0, 1, 1)?;
+ let res = t.conv1d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
);
- let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
+ let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
@@ -51,13 +51,13 @@ fn conv1d(dev: &Device) -> Result<()> {
fn conv1d_small(dev: &Device) -> Result<()> {
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
- let res = t.conv1d(&w, 0, 1, 1)?;
+ let res = t.conv1d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.4056, -0.8689]
);
- let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?;
+ let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -81,6 +81,10 @@ w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose2d(t, w_t)
print(res.shape)
print(res)
+
+res = torch.nn.functional.conv2d(t, w, dilation=2)
+print(res.shape)
+print(res[0])
*/
fn conv2d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@@ -113,7 +117,7 @@ fn conv2d(dev: &Device) -> Result<()> {
)?;
let t = t.reshape((1, 4, 5, 5))?;
let w = w.reshape((2, 4, 3, 3))?;
- let res = t.conv2d(&w, 0, 1, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -122,7 +126,7 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
@@ -147,6 +151,13 @@ fn conv2d(dev: &Device) -> Result<()> {
]
]
);
+ // Dilations.
+ let res = t.conv2d(&w, 0, 1, 2, 1)?;
+ assert_eq!(res.dims(), [1, 2, 1, 1]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [2.45, -2.3504],
+ );
Ok(())
}
@@ -182,13 +193,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?;
let t = t.reshape((1, 2, 3, 3))?;
let w = w.reshape((1, 2, 1, 1))?;
- let res = t.conv2d(&w, 0, 1, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539]
);
- let res = t.conv2d(&w, 2, 1, 1)?;
+ let res = t.conv2d(&w, 2, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 7, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -200,13 +211,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539],
);
- let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1)?;
+ let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [2, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -230,7 +241,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> {
let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?;
let t = t.reshape((1, 1, 3, 3))?;
let w = w.reshape((1, 1, 3, 3))?;
- let res = t.conv2d(&w, 0, 1, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 1, 1]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -261,7 +272,7 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?;
let t = t.reshape((1, 2, 4, 2))?;
let w = w.reshape((1, 2, 1, 1))?;
- let res = t.conv2d(&w, 0, 1, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 4, 2]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -270,6 +281,36 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
Ok(())
}
+/*
+import torch
+torch.manual_seed(4242)
+
+t = torch.randn((1, 4, 5, 5), requires_grad=True)
+w = torch.randn((2, 4, 3, 3), requires_grad=True)
+print(t.flatten())
+print(w.flatten())
+res = torch.nn.functional.conv2d(t, w)
+print(res.flatten())
+loss = (res ** 2).sum()
+print(loss)
+loss.backward()
+print(t.grad.shape)
+print(t.grad.flatten())
+print(w.grad.shape)
+print(w.grad.flatten())
+
+t.grad.zero_()
+w.grad.zero_()
+res = torch.nn.functional.conv2d(t, w, stride=2)
+print(res.flatten())
+loss = (res ** 2).sum()
+print(loss)
+loss.backward()
+print(t.grad.shape)
+print(t.grad[0])
+print(w.grad.shape)
+print(w.grad[0])
+*/
fn conv2d_grad(dev: &Device) -> Result<()> {
use candle_core::Var;
let t = Var::from_slice(
@@ -302,7 +343,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
(2, 4, 3, 3),
dev,
)?;
- let res = t.conv2d(&w, 0, 1, 1)?;
+ let res = t.conv2d(&w, 0, 1, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
let grads = loss.backward()?;
@@ -335,6 +376,74 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
-34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6
]
);
+
+ // Same as before but with stride.
+ let res = t.conv2d(&w, 0, 2, 1, 1)?;
+ let loss = res.sqr()?.sum_all()?;
+ assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32);
+ let grads = loss.backward()?;
+ let grad_t = grads.get(&t).unwrap();
+ let grad_w = grads.get(&w).unwrap();
+ assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+ assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
+ [
+ [
+ [9.29, -7.03, 0.94, 3.49, -7.71],
+ [-1.8, -7.82, 8.9, 8.46, 7.43],
+ [-25.84, 22.09, -19.27, -0.22, 1.69],
+ [4.02, 18.53, -18.37, 2.3, -24.51],
+ [7.72, -9.68, -12.34, 5.6, -20.22]
+ ],
+ [
+ [21.73, 3.39, -18.27, 3.86, -3.65],
+ [8.25, 3.73, 30.73, -8.61, -11.93],
+ [-72.15, -15.36, -17.53, -12.32, -1.61],
+ [-22.32, -7.79, -91.82, 6.44, -37.69],
+ [52.88, 14.44, 42.75, 9.88, 2.01]
+ ],
+ [
+ [-8.98, 9.91, 6.75, -4.68, 15.38],
+ [4.93, -0.33, 9.94, -1.46, 14.78],
+ [13.62, -30.63, 3.96, -3.58, -4.48],
+ [-14.13, 1.19, -34.43, 3.08, -33.83],
+ [17.28, 12.94, 31.83, -3.35, 6.81]
+ ],
+ [
+ [23.54, 6.98, -24.52, 0.52, 4.87],
+ [9.65, 6.18, 1.71, -25.23, -4.93],
+ [-54.99, -23.66, 3.19, -3.73, 18.58],
+ [-21.35, -10.39, -39.88, 28.73, -30.76],
+ [-9.13, 11.12, -14.0, -8.23, -11.25]
+ ]
+ ]
+ );
+ assert_eq!(
+ test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
+ [
+ [
+ [28.34, -45.75, 7.32],
+ [0.72, -35.28, 19.23],
+ [-28.29, 20.89, -5.18]
+ ],
+ [
+ [-16.04, -16.38, 32.12],
+ [57.5, 25.81, 11.96],
+ [-18.66, 8.48, -9.92]
+ ],
+ [
+ [2.93, 1.57, -23.76],
+ [12.74, -26.2, -17.88],
+ [-14.98, -9.35, 12.2]
+ ],
+ [
+ [-0.18, -6.82, 20.79],
+ [-2.54, 27.11, -10.11],
+ [-0.41, -3.18, -0.07]
+ ]
+ ]
+ );
Ok(())
}
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index 86e3b6e9..53b252ed 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -278,6 +278,7 @@ impl EncodecConv1d {
padding: 0,
stride,
groups: 1,
+ dilation: 1,
},
vb.pp("conv"),
)?,
@@ -289,6 +290,7 @@ impl EncodecConv1d {
padding: 0,
stride,
groups: 1,
+ dilation: 1,
},
vb.pp("conv"),
)?,
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
index 5f6a2558..4cfd386d 100644
--- a/candle-examples/examples/stable-diffusion/resnet.rs
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -66,6 +66,7 @@ impl ResnetBlock2D {
stride: 1,
padding: 1,
groups: 1,
+ dilation: 1,
};
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -80,6 +81,7 @@ impl ResnetBlock2D {
stride: 1,
padding: 0,
groups: 1,
+ dilation: 1,
};
Some(conv2d(
in_channels,
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 9f0c9ef8..d6bea09a 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -281,11 +281,13 @@ impl AudioEncoder {
padding: 1,
stride: 1,
groups: 1,
+ dilation: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
+ dilation: 1,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs
index de8fcf09..0c81bca8 100644
--- a/candle-examples/examples/yolo-v3/darknet.rs
+++ b/candle-examples/examples/yolo-v3/darknet.rs
@@ -132,6 +132,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
stride,
padding,
groups: 1,
+ dilation: 1,
};
let conv = if bias {
conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index 98a0cb63..d7fe5c12 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -93,6 +93,7 @@ impl ConvBlock {
padding,
stride,
groups: 1,
+ dilation: 1,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 5ccce317..c67a4300 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -8,6 +8,7 @@ __device__ void conv1d(
const size_t l_out,
const size_t stride,
const size_t padding,
+ const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@@ -36,7 +37,7 @@ __device__ void conv1d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t offset = 0; offset < k_size; ++offset) {
- size_t src_l = stride * dst_l + offset;
+ size_t src_l = (stride * dst_l + offset) * dilation;
if (src_l < padding || src_l >= padding + l_in) {
continue;
}
@@ -58,6 +59,7 @@ __device__ void conv2d(
const size_t h_out,
const size_t stride,
const size_t padding,
+ const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@@ -90,13 +92,13 @@ __device__ void conv2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
- size_t src_w = stride * dst_w + w_offset;
+ size_t src_w = (stride * dst_w + w_offset) * dilation;
if (src_w < padding || src_w >= w_in + padding) {
continue;
}
src_w -= padding;
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
- size_t src_h = stride * dst_h + h_offset;
+ size_t src_h = (stride * dst_h + h_offset) * dilation;
if (src_h < padding || src_h >= h_in + padding) {
continue;
}
@@ -120,6 +122,7 @@ __device__ void conv_transpose2d(
const size_t stride,
const size_t padding,
const size_t out_padding,
+ const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@@ -335,12 +338,13 @@ extern "C" __global__ void FN_NAME( \
const size_t num_dims, \
const size_t stride, \
const size_t padding, \
+ const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
- conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \
+ conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, dilation, info, src, kernel, dst); \
} \
#define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -350,12 +354,13 @@ extern "C" __global__ void FN_NAME( \
const size_t h_out, \
const size_t stride, \
const size_t padding, \
+ const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
- conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
+ conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -366,12 +371,13 @@ extern "C" __global__ void FN_NAME( \
const size_t stride, \
const size_t padding, \
const size_t out_padding, \
+ const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
- conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
+ conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
} \
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index e43de8ef..dbf23aa5 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -5,6 +5,7 @@ use candle::{Result, Tensor};
pub struct Conv1dConfig {
pub padding: usize,
pub stride: usize,
+ pub dilation: usize,
pub groups: usize,
}
@@ -13,6 +14,7 @@ impl Default for Conv1dConfig {
Self {
padding: 0,
stride: 1,
+ dilation: 1,
groups: 1,
}
}
@@ -45,6 +47,7 @@ impl crate::Module for Conv1d {
&self.weight,
self.config.padding,
self.config.stride,
+ self.config.dilation,
self.config.groups,
)?;
match &self.bias {
@@ -62,6 +65,7 @@ impl crate::Module for Conv1d {
pub struct Conv2dConfig {
pub padding: usize,
pub stride: usize,
+ pub dilation: usize,
pub groups: usize,
}
@@ -70,6 +74,7 @@ impl Default for Conv2dConfig {
Self {
padding: 0,
stride: 1,
+ dilation: 1,
groups: 1,
}
}
@@ -103,6 +108,7 @@ impl crate::Module for Conv2d {
&self.weight,
self.config.padding,
self.config.stride,
+ self.config.dilation,
self.config.groups,
)?;
match &self.bias {
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
index 72dbdcdd..239ceee5 100644
--- a/candle-wasm-examples/whisper/src/model.rs
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -269,11 +269,13 @@ impl AudioEncoder {
padding: 1,
stride: 1,
groups: 1,
+ dilation: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
+ dilation: 1,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs
index a63c6e94..e0fa7ac4 100644
--- a/candle-wasm-examples/yolo/src/model.rs
+++ b/candle-wasm-examples/yolo/src/model.rs
@@ -97,6 +97,7 @@ impl ConvBlock {
padding,
stride,
groups: 1,
+ dilation: 1,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;