summaryrefslogtreecommitdiff
path: root/candle-core/src/conv.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r--candle-core/src/conv.rs27
1 files changed, 17 insertions, 10 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index d9e0a9ab..1f3ef582 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -11,12 +11,12 @@ pub struct ParamsConv1D {
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self) -> usize {
- let dilation = 1;
- (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
+ (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -36,17 +36,16 @@ pub struct ParamsConv2D {
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
- let dilation = 1;
- (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1
+ (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
- let dilation = 1;
- (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1
+ (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
@@ -66,18 +65,17 @@ pub struct ParamsConvTranspose2D {
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
+ pub(crate) dilation: usize,
}
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
- let dilation = 1;
- (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
+ (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_w(&self) -> usize {
- let dilation = 1;
- (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
+ (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
- 2 * self.padding
}
@@ -96,6 +94,7 @@ impl Tensor {
kernel,
padding: params.padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -107,6 +106,7 @@ impl Tensor {
kernel: &Self,
padding: usize,
stride: usize,
+ dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
@@ -130,6 +130,7 @@ impl Tensor {
k_size,
padding,
stride,
+ dilation,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
@@ -154,6 +155,7 @@ impl Tensor {
kernel,
padding: params.padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
@@ -165,6 +167,7 @@ impl Tensor {
kernel: &Self,
padding: usize,
stride: usize,
+ dilation: usize,
groups: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
@@ -184,6 +187,7 @@ impl Tensor {
c_in: c_in / groups,
padding,
stride,
+ dilation,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
@@ -206,6 +210,7 @@ impl Tensor {
padding: usize,
output_padding: usize,
stride: usize,
+ dilation: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
@@ -223,6 +228,7 @@ impl Tensor {
padding,
output_padding,
stride,
+ dilation,
};
let storage = self.storage().conv_transpose2d(
self.layout(),
@@ -236,6 +242,7 @@ impl Tensor {
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
+ dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))