diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-12-23 11:19:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-23 11:19:22 +0100 |
commit | ba1fae590ee059d05ea421ec8072bd3e07ba9d40 (patch) | |
tree | b50d7beba0fa3cf5b537603c4be2bc973df58ed2 /candle-core/src/tensor.rs | |
parent | 78d982e1bdee80e0a246326b91e9a2aa552ec0fa (diff) | |
download | candle-ba1fae590ee059d05ea421ec8072bd3e07ba9d40.tar.gz candle-ba1fae590ee059d05ea421ec8072bd3e07ba9d40.tar.bz2 candle-ba1fae590ee059d05ea421ec8072bd3e07ba9d40.zip |
Validate the kernel size in pooling ops. (#1473)
* Validate the kernel size in pooling ops.
* Revert the changes to basics.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f15f8c1c..54f9fa2b 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -396,7 +396,7 @@ impl Tensor { device: &Device, ) -> Result<Self> { if D::is_zero(&step) { - crate::bail!("step cannot be zero") + bail!("step cannot be zero") } let mut data = vec![]; let mut current = start; @@ -1041,6 +1041,9 @@ impl Tensor { let kernel_size = kernel_size.to_usize2(); let stride = stride.to_usize2(); let (n, c, h, w) = self.dims4()?; + if h < kernel_size.0 || w < kernel_size.1 { + bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}") + } // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d let h_out = (h - kernel_size.0) / stride.0 + 1; let w_out = (w - kernel_size.1) / stride.1 + 1; @@ -1076,6 +1079,9 @@ impl Tensor { let kernel_size = kernel_size.to_usize2(); let stride = stride.to_usize2(); let (n, c, h, w) = self.dims4()?; + if h < kernel_size.0 || w < kernel_size.1 { + bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}") + } // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d let h_out = (h - kernel_size.0) / stride.0 + 1; let w_out = (w - kernel_size.1) / stride.1 + 1; @@ -1798,7 +1804,7 @@ impl Tensor { let is_permutation = dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i)); if !is_permutation { - crate::bail!( + bail!( "dimension mismatch in permute, tensor {:?}, dims: {:?}", self.dims(), dims @@ -2293,7 +2299,7 @@ impl Tensor { if left == 0 && right == 0 { Ok(self.clone()) } else if self.elem_count() == 0 { - crate::bail!("cannot use pad_with_same on an empty tensor") + bail!("cannot use pad_with_same on an empty tensor") } else if left == 0 { let dim = dim.to_index(self.shape(), "pad_with_same")?; let r = self.narrow(dim, self.dim(dim)? - 1, 1)?; @@ -2457,13 +2463,13 @@ impl Tensor { pub fn normalize_axis(&self, axis: i64) -> Result<usize> { let rank = self.rank() as i64; if rank <= axis { - crate::bail!("axis {axis} is too large, tensor rank {rank}") + bail!("axis {axis} is too large, tensor rank {rank}") } else if 0 <= axis { Ok(axis as usize) } else { let naxis = rank + axis; if naxis < 0 { - crate::bail!("axis {axis} is too small, tensor rank {rank}") + bail!("axis {axis} is too small, tensor rank {rank}") } Ok(naxis as usize) } @@ -2525,14 +2531,14 @@ impl Tensor { let src_dims = src.dims(); let self_dims = self.dims(); if self_dims.len() != src_dims.len() { - crate::bail!( + bail!( "slice-assign requires input with the same rank {} <> {}", self_dims.len(), src_dims.len() ) } if self_dims.len() != ranges.len() { - crate::bail!( + bail!( "slice-assign requires input with the same rank as there are ranges {} <> {}", self_dims.len(), ranges.len() @@ -2552,18 +2558,16 @@ impl Tensor { std::ops::Bound::Excluded(v) => *v, }; if end_excluded <= start_included { - crate::bail!( - "slice-assign: empty range for dim {i}, {start_included} {end_excluded}" - ) + bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}") } if self_dims[i] < end_excluded { - crate::bail!( + bail!( "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", self_dims[i] ) } if end_excluded - start_included != src_dims[i] { - crate::bail!( + bail!( "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] ) } |