summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs19
1 files changed, 16 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 5d4e106f..f9a6ebb5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -490,7 +490,7 @@ impl Tensor {
if dim >= self.dims().len() {
Err(Error::DimOutOfRange {
shape: self.shape().clone(),
- dim,
+ dim: dim as i32,
op,
})?
} else {
@@ -509,6 +509,7 @@ impl Tensor {
dim,
start,
len,
+ msg: "start + len > dim_len",
})?
}
if start == 0 && dims[dim] == len {
@@ -576,10 +577,22 @@ impl Tensor {
let (b_size, c_in, l_in) = match *self.dims() {
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
[c_in, l_in] => (None, c_in, l_in),
- _ => todo!("proper error message"),
+ _ => Err(Error::Conv1dInvalidArgs {
+ inp_shape: self.shape().clone(),
+ k_shape: kernel.shape().clone(),
+ padding,
+ stride,
+ msg: "input rank is not 2 or 3",
+ })?,
};
if c_in != c_in_k {
- todo!("proper error message")
+ Err(Error::Conv1dInvalidArgs {
+ inp_shape: self.shape().clone(),
+ k_shape: kernel.shape().clone(),
+ padding,
+ stride,
+ msg: "the number of in-channels on the input doesn't match the kernel size",
+ })?
}
let params = crate::conv::ParamsConv1D {
b_size,