summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-12-23 11:19:22 +0100
committerGitHub <noreply@github.com>2023-12-23 11:19:22 +0100
commitba1fae590ee059d05ea421ec8072bd3e07ba9d40 (patch)
treeb50d7beba0fa3cf5b537603c4be2bc973df58ed2 /candle-core/src/tensor.rs
parent78d982e1bdee80e0a246326b91e9a2aa552ec0fa (diff)
downloadcandle-ba1fae590ee059d05ea421ec8072bd3e07ba9d40.tar.gz
candle-ba1fae590ee059d05ea421ec8072bd3e07ba9d40.tar.bz2
candle-ba1fae590ee059d05ea421ec8072bd3e07ba9d40.zip
Validate the kernel size in pooling ops. (#1473)
* Validate the kernel size in pooling ops. * Revert the changes to basics.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs28
1 files changed, 16 insertions, 12 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f15f8c1c..54f9fa2b 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -396,7 +396,7 @@ impl Tensor {
device: &Device,
) -> Result<Self> {
if D::is_zero(&step) {
- crate::bail!("step cannot be zero")
+ bail!("step cannot be zero")
}
let mut data = vec![];
let mut current = start;
@@ -1041,6 +1041,9 @@ impl Tensor {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
+ if h < kernel_size.0 || w < kernel_size.1 {
+ bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
+ }
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
@@ -1076,6 +1079,9 @@ impl Tensor {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
+ if h < kernel_size.0 || w < kernel_size.1 {
+ bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
+ }
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
@@ -1798,7 +1804,7 @@ impl Tensor {
let is_permutation =
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
if !is_permutation {
- crate::bail!(
+ bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
dims
@@ -2293,7 +2299,7 @@ impl Tensor {
if left == 0 && right == 0 {
Ok(self.clone())
} else if self.elem_count() == 0 {
- crate::bail!("cannot use pad_with_same on an empty tensor")
+ bail!("cannot use pad_with_same on an empty tensor")
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
@@ -2457,13 +2463,13 @@ impl Tensor {
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
- crate::bail!("axis {axis} is too large, tensor rank {rank}")
+ bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
- crate::bail!("axis {axis} is too small, tensor rank {rank}")
+ bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
@@ -2525,14 +2531,14 @@ impl Tensor {
let src_dims = src.dims();
let self_dims = self.dims();
if self_dims.len() != src_dims.len() {
- crate::bail!(
+ bail!(
"slice-assign requires input with the same rank {} <> {}",
self_dims.len(),
src_dims.len()
)
}
if self_dims.len() != ranges.len() {
- crate::bail!(
+ bail!(
"slice-assign requires input with the same rank as there are ranges {} <> {}",
self_dims.len(),
ranges.len()
@@ -2552,18 +2558,16 @@ impl Tensor {
std::ops::Bound::Excluded(v) => *v,
};
if end_excluded <= start_included {
- crate::bail!(
- "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
- )
+ bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
}
if self_dims[i] < end_excluded {
- crate::bail!(
+ bail!(
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
self_dims[i]
)
}
if end_excluded - start_included != src_dims[i] {
- crate::bail!(
+ bail!(
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
)
}