author     Laurent Mazare <laurent.mazare@gmail.com>  2023-07-12 09:17:08 +0100
committer  GitHub <noreply@github.com>                2023-07-12 09:17:08 +0100
commit     a76ec797da866d92a39ca41a2c9e70de6d6d7df7 (patch)
tree       56021e6e33aebaca580a72331b32cd5ae1145fe1
parent     fa760759e5fa94c8486566af6dd3a456d0548221 (diff)
Cleanup the main crate error and add a couple dedicated ones (#142)
* Cosmetic cleanups to the error enum.
* More error cleanup.
* Proper error handling rather than panicking.
* Add some dedicated conv1d errors.
-rw-r--r--  candle-core/src/cpu_backend.rs   43
-rw-r--r--  candle-core/src/error.rs        135
-rw-r--r--  candle-core/src/layout.rs        20
-rw-r--r--  candle-core/src/shape.rs         34
-rw-r--r--  candle-core/src/tensor.rs        19
-rw-r--r--  candle-nn/src/var_builder.rs      5
6 files changed, 144 insertions(+), 112 deletions(-)
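
The running theme of the patch is replacing panics and catch-all variants with dedicated, descriptive error types. As a rough downstream illustration (not part of the commit), code consuming candle-core could now match on the new variants; this sketch assumes the Error enum exactly as declared in the error.rs hunk below, and the describe function itself is hypothetical:

use candle_core::Error;

// Hypothetical handler; the variant and field names come from the error.rs
// hunk below, everything else is illustrative.
fn describe(err: &Error) -> String {
    match err {
        Error::Conv1dInvalidArgs { msg, inp_shape, k_shape, .. } => {
            format!("conv1d rejected ({msg}): input {inp_shape:?}, kernel {k_shape:?}")
        }
        // The payload is boxed (see the "Box indirection" comment in error.rs)
        // so this variant stays pointer-sized.
        Error::MatMulUnexpectedStriding(details) => {
            format!("matmul striding issue ({}), bmnk: {:?}", details.msg, details.bmnk)
        }
        other => other.to_string(),
    }
}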
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index de32b549..a20d032d 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -261,6 +261,17 @@ impl<'a> Map2 for Conv1D<'a> {
struct MatMul((usize, usize, usize, usize));
+impl MatMul {
+ fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
+ Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
+ lhs_l: lhs_l.clone(),
+ rhs_l: rhs_l.clone(),
+ bmnk: self.0,
+ msg,
+ }))
+ }
+}
+
impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
@@ -290,19 +301,13 @@ impl Map2 for MatMul {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
- _ => Err(Error::UnexpectedStriding {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- })?,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
- _ => Err(Error::UnexpectedStriding {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- })?,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n;
@@ -369,19 +374,13 @@ impl Map2 for MatMul {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
- _ => Err(Error::UnexpectedStriding {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- })?,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
- _ => Err(Error::UnexpectedStriding {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- })?,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n;
@@ -395,11 +394,7 @@ impl Map2 for MatMul {
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T')
} else {
- Err(Error::MatMulNonContiguous {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- mnk: (m, n, k),
- })?
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -407,11 +402,7 @@ impl Map2 for MatMul {
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T')
} else {
- Err(Error::MatMulNonContiguous {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- mnk: (m, n, k),
- })?
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
};
let mut dst = vec![T::zero(); b * m * n];
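
Two idioms in the hunk above are worth noting: the new striding_error helper deduplicates four near-identical error constructions, and each fallback arm uses Err(...)? in a position where a usize is expected, which type-checks because ? early-returns the (From-converted) error, making the arm diverge. A minimal standalone sketch of the Err(...)? idiom, with illustrative names:

#[derive(Debug)]
struct StridingError(&'static str);

fn batch_skip(stride: &[usize]) -> Result<usize, StridingError> {
    let skip = match stride {
        [s] => *s,
        [] => 1,
        // Diverges: `?` converts and returns the error, so this arm never
        // has to produce a usize.
        _ => Err(StridingError("unexpected striding"))?,
    };
    Ok(skip)
}

fn main() {
    assert!(batch_skip(&[4]).is_ok());
    assert!(batch_skip(&[8, 4, 1]).is_err());
}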
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 27fd11bb..b9131356 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -1,8 +1,17 @@
-use crate::{DType, DeviceLocation, Shape};
+use crate::{DType, DeviceLocation, Layout, Shape};
+
+#[derive(Debug, Clone)]
+pub struct MatMulUnexpectedStriding {
+ pub lhs_l: Layout,
+ pub rhs_l: Layout,
+ pub bmnk: (usize, usize, usize, usize),
+ pub msg: &'static str,
+}
/// Main library error type.
#[derive(thiserror::Error, Debug)]
pub enum Error {
+ // === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
@@ -10,47 +19,39 @@ pub enum Error {
got: DType,
},
- #[error("{msg}, expected: {expected:?}, got: {got:?}")]
- UnexpectedShape {
- msg: String,
- expected: Shape,
- got: Shape,
+ #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
+ DTypeMismatchBinaryOp {
+ lhs: DType,
+ rhs: DType,
+ op: &'static str,
},
+ #[error("unsupported dtype {0:?} for op {1}")]
+ UnsupportedDTypeForOp(DType, &'static str),
+
+ // === Dimension Index Errors ===
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
DimOutOfRange {
shape: Shape,
- dim: usize,
+ dim: i32,
op: &'static str,
},
- #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
- NarrowInvalidArgs {
+ // === Shape Errors ===
+ #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
+ UnexpectedNumberOfDims {
+ expected: usize,
+ got: usize,
shape: Shape,
- dim: usize,
- start: usize,
- len: usize,
},
- #[error("{op} only supports contiguous tensors")]
- RequiresContiguous { op: &'static str },
-
- #[error("{op} expects at least one tensor")]
- OpRequiresAtLeastOneTensor { op: &'static str },
-
- #[error("backward is not supported for {op}")]
- BackwardNotSupported { op: &'static str },
-
- #[error("{op} invalid index {index} with vocab {vocab_size}")]
- InvalidIndex {
- op: &'static str,
- index: usize,
- vocab_size: usize,
+ #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+ UnexpectedShape {
+ msg: String,
+ expected: Shape,
+ got: Shape,
},
- #[error("the candle crate has not been built with cuda support")]
- NotCompiledWithCudaSupport,
-
#[error(
"Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
)]
@@ -71,6 +72,7 @@ pub enum Error {
nth_shape: Shape,
},
+ // === Device Errors ===
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
DeviceMismatchBinaryOp {
lhs: DeviceLocation,
@@ -78,27 +80,56 @@ pub enum Error {
op: &'static str,
},
- #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
- DTypeMismatchBinaryOp {
- lhs: DType,
- rhs: DType,
- op: &'static str,
+ // === Op Specific Errors ===
+ #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
+ NarrowInvalidArgs {
+ shape: Shape,
+ dim: usize,
+ start: usize,
+ len: usize,
+ msg: &'static str,
},
- #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
- UnexpectedNumberOfDims {
- expected: usize,
- got: usize,
- shape: Shape,
+ #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
+ Conv1dInvalidArgs {
+ inp_shape: Shape,
+ k_shape: Shape,
+ padding: usize,
+ stride: usize,
+ msg: &'static str,
},
- // TODO this is temporary when we support arbitrary matmul
- #[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
- UnexpectedStriding {
- lhs_stride: Vec<usize>,
- rhs_stride: Vec<usize>,
+ #[error("{op} invalid index {index} with vocab {vocab_size}")]
+ InvalidIndex {
+ op: &'static str,
+ index: usize,
+ vocab_size: usize,
},
+ #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
+ BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
+
+ // Box indirection to avoid large variant.
+ #[error("{0:?}")]
+ MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
+
+ #[error("{op} only supports contiguous tensors")]
+ RequiresContiguous { op: &'static str },
+
+ #[error("{op} expects at least one tensor")]
+ OpRequiresAtLeastOneTensor { op: &'static str },
+
+ #[error("backward is not supported for {op}")]
+ BackwardNotSupported { op: &'static str },
+
+ // === Other Errors ===
+ #[error("the candle crate has not been built with cuda support")]
+ NotCompiledWithCudaSupport,
+
+ #[error("cannot find tensor {path}")]
+ CannotFindTensor { path: String },
+
+ // === Wrapped Errors ===
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),
@@ -126,22 +157,6 @@ pub enum Error {
#[error("unsupported safetensor dtype {0:?}")]
UnsupportedSafeTensorDtype(safetensors::Dtype),
-
- #[error("unsupported dtype {0:?} for op {1}")]
- UnsupportedDTypeForOp(DType, &'static str),
-
- #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
- BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
-
- #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
- MatMulNonContiguous {
- lhs_stride: Vec<usize>,
- rhs_stride: Vec<usize>,
- mnk: (usize, usize, usize),
- },
-
- #[error("cannot find tensor {path}")]
- CannotFindTensor { path: String },
}
pub type Result<T> = std::result::Result<T, Error>;
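
The "Box indirection to avoid large variant" comment is worth unpacking: a Rust enum is as large as its largest variant, and MatMulUnexpectedStriding carries two Layout clones plus a 4-tuple, so storing it inline would inflate every Error value, and every Result<T>, that crosses an API boundary. A standalone sketch of the effect, using a stand-in payload rather than the real Layout type:

use std::mem::size_of;

#[allow(dead_code)]
struct BigPayload {
    lhs_dims: [usize; 8],
    rhs_dims: [usize; 8],
    bmnk: (usize, usize, usize, usize),
}

#[allow(dead_code)]
enum InlineError {
    Msg(&'static str),
    Striding(BigPayload), // the whole payload lives inside the enum
}

#[allow(dead_code)]
enum BoxedError {
    Msg(&'static str),
    Striding(Box<BigPayload>), // only a pointer lives inside the enum
}

fn main() {
    // The boxed version stays pointer-sized regardless of the payload.
    println!("inline: {} bytes", size_of::<InlineError>());
    println!("boxed:  {} bytes", size_of::<BoxedError>());
    assert!(size_of::<BoxedError>() < size_of::<InlineError>());
}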
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index 3f629d50..79d40cfc 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -60,20 +60,26 @@ impl Layout {
self.shape.is_fortran_contiguous(&self.stride)
}
- pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+ pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
let dims = self.shape().dims();
if dim >= dims.len() {
- Err(Error::UnexpectedNumberOfDims {
- expected: dim + 1,
- got: dims.len(),
+ Err(Error::DimOutOfRange {
shape: self.shape().clone(),
+ dim: dim as i32,
+ op: "narrow",
})?
}
- if start + length > dims[dim] {
- todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
+ if start + len > dims[dim] {
+ Err(Error::NarrowInvalidArgs {
+ shape: self.shape.clone(),
+ dim,
+ start,
+ len,
+ msg: "start + len > dim_len",
+ })?
}
let mut dims = dims.to_vec();
- dims[dim] = length;
+ dims[dim] = len;
Ok(Self {
shape: Shape::from(dims),
stride: self.stride.clone(),
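
The rewritten narrow now reports both failure modes through typed errors instead of a todo! panic. A standalone mirror of the two checks to make the conditions concrete, using plain slices instead of Layout and String instead of Error:

fn check_narrow(dims: &[usize], dim: usize, start: usize, len: usize) -> Result<(), String> {
    if dim >= dims.len() {
        // Mirrors Error::DimOutOfRange { op: "narrow", .. }
        return Err(format!("narrow: dimension index {dim} out of range for {dims:?}"));
    }
    if start + len > dims[dim] {
        // Mirrors Error::NarrowInvalidArgs { msg: "start + len > dim_len", .. }
        return Err(format!(
            "narrow invalid args start + len > dim_len: {dims:?}, dim: {dim}, start: {start}, len: {len}"
        ));
    }
    Ok(())
}

fn main() {
    assert!(check_narrow(&[2, 5], 1, 3, 2).is_ok());  // 3 + 2 <= 5
    assert!(check_narrow(&[2, 5], 1, 3, 4).is_err()); // 3 + 4 > 5
    assert!(check_narrow(&[2, 5], 2, 0, 1).is_err()); // only dims 0 and 1 exist
}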
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 632ef116..b5e64454 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -194,7 +194,7 @@ impl Dim for usize {
if dim >= shape.dims().len() {
Err(Error::DimOutOfRange {
shape: shape.clone(),
- dim,
+ dim: dim as i32,
op,
})?
} else {
@@ -207,7 +207,7 @@ impl Dim for usize {
if dim > shape.dims().len() {
Err(Error::DimOutOfRange {
shape: shape.clone(),
- dim,
+ dim: dim as i32,
op,
})?
} else {
@@ -221,30 +221,36 @@ pub enum D {
Minus2,
}
+impl D {
+ fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
+ let dim = match self {
+ Self::Minus1 => -1,
+ Self::Minus2 => -2,
+ };
+ Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim,
+ op,
+ }
+ }
+}
+
impl Dim for D {
fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
let rank = shape.rank();
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
- _ => Err(Error::DimOutOfRange {
- shape: shape.clone(),
- dim: 42, // TODO: Have an adequate error
- op,
- }),
+ _ => Err(self.out_of_range(shape, op)),
}
}
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
let rank = shape.rank();
match self {
- Self::Minus1 if rank >= 1 => Ok(rank),
- Self::Minus2 if rank >= 2 => Ok(rank - 1),
- _ => Err(Error::DimOutOfRange {
- shape: shape.clone(),
- dim: 42, // TODO: Have an adequate error
- op,
- }),
+ Self::Minus1 => Ok(rank),
+ Self::Minus2 if rank >= 1 => Ok(rank - 1),
+ _ => Err(self.out_of_range(shape, op)),
}
}
}
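
The out_of_range helper is also why DimOutOfRange's dim field became an i32 earlier in this patch: D::Minus1 and D::Minus2 can now be reported as -1 and -2 rather than the old placeholder 42. A standalone mirror of the resolution logic (Option instead of candle's Result):

enum D {
    Minus1,
    Minus2,
}

fn to_index(d: &D, rank: usize) -> Option<usize> {
    match d {
        D::Minus1 if rank >= 1 => Some(rank - 1),
        D::Minus2 if rank >= 2 => Some(rank - 2),
        // In candle this arm becomes Error::DimOutOfRange with dim -1 or -2.
        _ => None,
    }
}

fn main() {
    assert_eq!(to_index(&D::Minus1, 3), Some(2)); // last dim of a rank-3 shape
    assert_eq!(to_index(&D::Minus2, 3), Some(1)); // second-to-last dim
    assert_eq!(to_index(&D::Minus2, 1), None);    // a rank-1 shape has no dim -2
}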
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 5d4e106f..f9a6ebb5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -490,7 +490,7 @@ impl Tensor {
if dim >= self.dims().len() {
Err(Error::DimOutOfRange {
shape: self.shape().clone(),
- dim,
+ dim: dim as i32,
op,
})?
} else {
@@ -509,6 +509,7 @@ impl Tensor {
dim,
start,
len,
+ msg: "start + len > dim_len",
})?
}
if start == 0 && dims[dim] == len {
@@ -576,10 +577,22 @@ impl Tensor {
let (b_size, c_in, l_in) = match *self.dims() {
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
[c_in, l_in] => (None, c_in, l_in),
- _ => todo!("proper error message"),
+ _ => Err(Error::Conv1dInvalidArgs {
+ inp_shape: self.shape().clone(),
+ k_shape: kernel.shape().clone(),
+ padding,
+ stride,
+ msg: "input rank is not 2 or 3",
+ })?,
};
if c_in != c_in_k {
- todo!("proper error message")
+ Err(Error::Conv1dInvalidArgs {
+ inp_shape: self.shape().clone(),
+ k_shape: kernel.shape().clone(),
+ padding,
+ stride,
+ msg: "the number of in-channels on the input doesn't match the kernel size",
+ })?
}
let params = crate::conv::ParamsConv1D {
b_size,
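
Both todo!() placeholders in conv1d are replaced by the new Conv1dInvalidArgs variant: the input must be rank 2 (c_in, l_in) or rank 3 (b_size, c_in, l_in), and its channel count must match the kernel's. A standalone mirror of the two checks, using plain slices instead of Tensor and &'static str instead of Error:

fn check_conv1d(inp_dims: &[usize], c_in_k: usize) -> Result<(), &'static str> {
    let c_in = match *inp_dims {
        [_b_size, c_in, _l_in] => c_in, // batched input
        [c_in, _l_in] => c_in,          // unbatched input
        _ => return Err("input rank is not 2 or 3"),
    };
    if c_in != c_in_k {
        return Err("the number of in-channels on the input doesn't match the kernel");
    }
    Ok(())
}

fn main() {
    assert!(check_conv1d(&[8, 3, 100], 3).is_ok()); // batch of 8, channels match
    assert!(check_conv1d(&[3, 100], 3).is_ok());    // unbatched is accepted too
    assert!(check_conv1d(&[3, 100], 4).is_err());   // channel mismatch
    assert!(check_conv1d(&[100], 3).is_err());      // rank 1 is rejected
}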
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 7f68ae08..aa2ec401 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -157,8 +157,9 @@ impl<'a> VarBuilder<'a> {
routing,
safetensors,
} => {
- // Unwrap or 0 just to let the proper error flow.
- let index = routing.get(&path).unwrap_or(&0);
+ let index = routing.get(&path).ok_or_else(|| Error::CannotFindTensor {
+ path: path.to_string(),
+ })?;
safetensors[*index]
.tensor(&path, &data.device)?
.to_dtype(data.dtype)?
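
The old code fell back to index 0 on a missing routing entry ("just to let the proper error flow"), relying on the later lookup in the first safetensors file to fail with a less specific error; ok_or_else surfaces the missing key immediately as CannotFindTensor. A minimal sketch of the Option-to-Result pattern with a plain HashMap and String error:

use std::collections::HashMap;

fn find_index(routing: &HashMap<String, usize>, path: &str) -> Result<usize, String> {
    routing
        .get(path)
        .copied()
        // The closure only runs when the key is absent, so the happy path
        // allocates no error message.
        .ok_or_else(|| format!("cannot find tensor {path}"))
}

fn main() {
    let routing = HashMap::from([("model.weight".to_string(), 1)]);
    assert_eq!(find_index(&routing, "model.weight"), Ok(1));
    assert!(find_index(&routing, "model.missing").is_err());
}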