diff options
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r-- | candle-core/src/conv.rs | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index d9e0a9ab..1f3ef582 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -11,12 +11,12 @@ pub struct ParamsConv1D { pub(crate) k_size: usize, pub(crate) padding: usize, pub(crate) stride: usize, + pub(crate) dilation: usize, } impl ParamsConv1D { pub(crate) fn l_out(&self) -> usize { - let dilation = 1; - (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1 + (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1 } pub(crate) fn out_dims(&self) -> Vec<usize> { @@ -36,17 +36,16 @@ pub struct ParamsConv2D { pub(crate) c_in: usize, pub(crate) padding: usize, pub(crate) stride: usize, + pub(crate) dilation: usize, } impl ParamsConv2D { pub(crate) fn out_h(&self) -> usize { - let dilation = 1; - (self.i_h + 2 * self.padding - dilation * (self.k_h - 1) - 1) / self.stride + 1 + (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1 } pub(crate) fn out_w(&self) -> usize { - let dilation = 1; - (self.i_w + 2 * self.padding - dilation * (self.k_w - 1) - 1) / self.stride + 1 + (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1 } pub(crate) fn out_dims(&self) -> Vec<usize> { @@ -66,18 +65,17 @@ pub struct ParamsConvTranspose2D { pub(crate) padding: usize, pub(crate) output_padding: usize, pub(crate) stride: usize, + pub(crate) dilation: usize, } impl ParamsConvTranspose2D { pub(crate) fn out_h(&self) -> usize { - let dilation = 1; - (self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1 + (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1 - 2 * self.padding } pub(crate) fn out_w(&self) -> usize { - let dilation = 1; - (self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1 + (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1 - 2 * self.padding } @@ -96,6 +94,7 @@ impl Tensor { kernel, padding: params.padding, stride: params.stride, + dilation: params.dilation, }); let out_dims = params.out_dims(); Ok(crate::tensor::from_storage(storage, out_dims, op, false)) @@ -107,6 +106,7 @@ impl Tensor { kernel: &Self, padding: usize, stride: usize, + dilation: usize, groups: usize, ) -> Result<Self> { let (c_out, c_in_k, k_size) = kernel.dims3()?; @@ -130,6 +130,7 @@ impl Tensor { k_size, padding, stride, + dilation, }; if groups == 1 { self.conv1d_single_group(kernel, ¶ms) @@ -154,6 +155,7 @@ impl Tensor { kernel, padding: params.padding, stride: params.stride, + dilation: params.dilation, }); let out_dims = params.out_dims(); Ok(crate::tensor::from_storage(storage, out_dims, op, false)) @@ -165,6 +167,7 @@ impl Tensor { kernel: &Self, padding: usize, stride: usize, + dilation: usize, groups: usize, ) -> Result<Self> { let (b_size, c_in, i_h, i_w) = self.dims4()?; @@ -184,6 +187,7 @@ impl Tensor { c_in: c_in / groups, padding, stride, + dilation, }; if groups == 1 { self.conv2d_single_group(kernel, ¶ms) @@ -206,6 +210,7 @@ impl Tensor { padding: usize, output_padding: usize, stride: usize, + dilation: usize, ) -> Result<Self> { let (b_size, c_in, i_h, i_w) = self.dims4()?; let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?; @@ -223,6 +228,7 @@ impl Tensor { padding, output_padding, stride, + dilation, }; let storage = self.storage().conv_transpose2d( self.layout(), @@ -236,6 +242,7 @@ impl Tensor { padding: params.padding, output_padding: params.output_padding, stride: params.stride, + dilation: params.dilation, }); let out_dims = params.out_dims(); Ok(crate::tensor::from_storage(storage, out_dims, op, false)) |