author     Laurent Mazare <laurent.mazare@gmail.com>   2023-07-12 09:17:08 +0100
committer  GitHub <noreply@github.com>                 2023-07-12 09:17:08 +0100
commit     a76ec797da866d92a39ca41a2c9e70de6d6d7df7 (patch)
tree       56021e6e33aebaca580a72331b32cd5ae1145fe1
parent     fa760759e5fa94c8486566af6dd3a456d0548221 (diff)
Cleanup the main crate error and add a couple dedicated ones (#142)
* Cosmetic cleanups to the error enum.
* More error cleanup.
* Proper error handling rather than panicking (the pattern is sketched below).
* Add a dedicated conv1d error.
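The recurring pattern in this patch is replacing `todo!()` calls and catch-all variants with dedicated, data-carrying error variants. A minimal standalone sketch of the idea, using `thiserror` as the crate does; the `Shape` alias and the `narrow_check` helper are simplified stand-ins for illustration, not the crate's actual code:

    use thiserror::Error;

    // Simplified stand-in for candle's `Shape` type.
    type Shape = Vec<usize>;

    #[derive(Error, Debug)]
    pub enum Error {
        // A dedicated variant carries everything needed to diagnose the failure.
        #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len: {len}")]
        NarrowInvalidArgs {
            shape: Shape,
            dim: usize,
            start: usize,
            len: usize,
            msg: &'static str,
        },
    }

    pub type Result<T> = std::result::Result<T, Error>;

    // Instead of `todo!("add a proper error")`, the bounds check returns the variant.
    fn narrow_check(shape: &Shape, dim: usize, start: usize, len: usize) -> Result<()> {
        if start + len > shape[dim] {
            return Err(Error::NarrowInvalidArgs {
                shape: shape.clone(),
                dim,
                start,
                len,
                msg: "start + len > dim_len",
            });
        }
        Ok(())
    }

Carrying a `&'static str` `msg` keeps the variants cheap to construct while still distinguishing which precondition failed.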
-rw-r--r--  candle-core/src/cpu_backend.rs   43
-rw-r--r--  candle-core/src/error.rs        135
-rw-r--r--  candle-core/src/layout.rs        20
-rw-r--r--  candle-core/src/shape.rs         34
-rw-r--r--  candle-core/src/tensor.rs        19
-rw-r--r--  candle-nn/src/var_builder.rs      5
6 files changed, 144 insertions, 112 deletions
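For downstream code, the payoff is that these failures become matchable values instead of panics. A hedged sketch of caller-side handling: the `report` helper is hypothetical, and the `use candle_core::Error` path assumes the error type is re-exported at the crate root.

    use candle_core::Error;

    // Hypothetical reporting helper matching on the variants introduced here.
    fn report(err: &Error) {
        match err {
            Error::Conv1dInvalidArgs { msg, inp_shape, k_shape, .. } => {
                eprintln!("conv1d rejected ({msg}): input {inp_shape:?}, kernel {k_shape:?}");
            }
            // The boxed payload's fields are public, so it can be inspected directly.
            Error::MatMulUnexpectedStriding(detail) => {
                eprintln!("matmul striding issue ({}): bmnk = {:?}", detail.msg, detail.bmnk);
            }
            other => eprintln!("{other}"),
        }
    }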
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index de32b549..a20d032d 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -261,6 +261,17 @@ impl<'a> Map2 for Conv1D<'a> {
 
 struct MatMul((usize, usize, usize, usize));
 
+impl MatMul {
+    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
+        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
+            lhs_l: lhs_l.clone(),
+            rhs_l: rhs_l.clone(),
+            bmnk: self.0,
+            msg,
+        }))
+    }
+}
+
 impl Map2 for MatMul {
     const OP: &'static str = "mat_mul";
 
@@ -290,19 +301,13 @@ impl Map2 for MatMul {
             [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => m * k,
-            _ => Err(Error::UnexpectedStriding {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-            })?,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
         };
         let b_skip: usize = match rhs_stride[..rank - 2] {
             [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => n * k,
-            _ => Err(Error::UnexpectedStriding {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-            })?,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
         };
         let c_skip: usize = m * n;
 
@@ -369,19 +374,13 @@ impl Map2 for MatMul {
             [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => m * k,
-            _ => Err(Error::UnexpectedStriding {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-            })?,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
         };
         let b_skip: usize = match rhs_stride[..rank - 2] {
             [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => n * k,
-            _ => Err(Error::UnexpectedStriding {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-            })?,
+            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
         };
         let c_skip: usize = m * n;
 
@@ -395,11 +394,7 @@ impl Map2 for MatMul {
         } else if rhs_m1 == k && rhs_m2 == 1 {
             (k as i32, b'T')
         } else {
-            Err(Error::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?
+            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
         };
         // The b tensor has dims batching, m, k (lhs)
         let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -407,11 +402,7 @@ impl Map2 for MatMul {
         } else if lhs_m1 == m && lhs_m2 == 1 {
             (m as i32, b'T')
         } else {
-            Err(Error::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?
+            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
         };
         let mut dst = vec![T::zero(); b * m * n];
 
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 27fd11bb..b9131356 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -1,8 +1,17 @@
-use crate::{DType, DeviceLocation, Shape};
+use crate::{DType, DeviceLocation, Layout, Shape};
+
+#[derive(Debug, Clone)]
+pub struct MatMulUnexpectedStriding {
+    pub lhs_l: Layout,
+    pub rhs_l: Layout,
+    pub bmnk: (usize, usize, usize, usize),
+    pub msg: &'static str,
+}
 
 /// Main library error type.
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
+    // === DType Errors ===
     #[error("{msg}, expected: {expected:?}, got: {got:?}")]
     UnexpectedDType {
         msg: &'static str,
@@ -10,47 +19,39 @@ pub enum Error {
         got: DType,
     },
 
-    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
-    UnexpectedShape {
-        msg: String,
-        expected: Shape,
-        got: Shape,
+    #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
+    DTypeMismatchBinaryOp {
+        lhs: DType,
+        rhs: DType,
+        op: &'static str,
     },
 
+    #[error("unsupported dtype {0:?} for op {1}")]
+    UnsupportedDTypeForOp(DType, &'static str),
+
+    // === Dimension Index Errors ===
     #[error("{op}: dimension index {dim} out of range for {shape:?}")]
     DimOutOfRange {
         shape: Shape,
-        dim: usize,
+        dim: i32,
         op: &'static str,
     },
 
-    #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
-    NarrowInvalidArgs {
+    // === Shape Errors ===
+    #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
+    UnexpectedNumberOfDims {
+        expected: usize,
+        got: usize,
         shape: Shape,
-        dim: usize,
-        start: usize,
-        len: usize,
     },
 
-    #[error("{op} only supports contiguous tensors")]
-    RequiresContiguous { op: &'static str },
-
-    #[error("{op} expects at least one tensor")]
-    OpRequiresAtLeastOneTensor { op: &'static str },
-
-    #[error("backward is not supported for {op}")]
-    BackwardNotSupported { op: &'static str },
-
-    #[error("{op} invalid index {index} with vocab {vocab_size}")]
-    InvalidIndex {
-        op: &'static str,
-        index: usize,
-        vocab_size: usize,
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedShape {
+        msg: String,
+        expected: Shape,
+        got: Shape,
     },
 
-    #[error("the candle crate has not been built with cuda support")]
-    NotCompiledWithCudaSupport,
-
     #[error(
         "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
     )]
@@ -71,6 +72,7 @@ pub enum Error {
         nth_shape: Shape,
     },
 
+    // === Device Errors ===
     #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
     DeviceMismatchBinaryOp {
         lhs: DeviceLocation,
@@ -78,27 +80,56 @@ pub enum Error {
         op: &'static str,
     },
 
-    #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
-    DTypeMismatchBinaryOp {
-        lhs: DType,
-        rhs: DType,
-        op: &'static str,
+    // === Op Specific Errors ===
+    #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
+    NarrowInvalidArgs {
+        shape: Shape,
+        dim: usize,
+        start: usize,
+        len: usize,
+        msg: &'static str,
     },
 
-    #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
-    UnexpectedNumberOfDims {
-        expected: usize,
-        got: usize,
-        shape: Shape,
+    #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
+    Conv1dInvalidArgs {
+        inp_shape: Shape,
+        k_shape: Shape,
+        padding: usize,
+        stride: usize,
+        msg: &'static str,
     },
 
-    // TODO this is temporary when we support arbitrary matmul
-    #[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
-    UnexpectedStriding {
-        lhs_stride: Vec<usize>,
-        rhs_stride: Vec<usize>,
+    #[error("{op} invalid index {index} with vocab {vocab_size}")]
+    InvalidIndex {
+        op: &'static str,
+        index: usize,
+        vocab_size: usize,
     },
 
+    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
+    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
+
+    // Box indirection to avoid large variant.
+    #[error("{0:?}")]
+    MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
+
+    #[error("{op} only supports contiguous tensors")]
+    RequiresContiguous { op: &'static str },
+
+    #[error("{op} expects at least one tensor")]
+    OpRequiresAtLeastOneTensor { op: &'static str },
+
+    #[error("backward is not supported for {op}")]
+    BackwardNotSupported { op: &'static str },
+
+    // === Other Errors ===
+    #[error("the candle crate has not been built with cuda support")]
+    NotCompiledWithCudaSupport,
+
+    #[error("cannot find tensor {path}")]
+    CannotFindTensor { path: String },
+
+    // === Wrapped Errors ===
     #[error(transparent)]
     Cuda(Box<dyn std::error::Error + Send + Sync>),
 
@@ -126,22 +157,6 @@ pub enum Error {
 
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
-
-    #[error("unsupported dtype {0:?} for op {1}")]
-    UnsupportedDTypeForOp(DType, &'static str),
-
-    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
-    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
-
-    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
-    MatMulNonContiguous {
-        lhs_stride: Vec<usize>,
-        rhs_stride: Vec<usize>,
-        mnk: (usize, usize, usize),
-    },
-
-    #[error("cannot find tensor {path}")]
-    CannotFindTensor { path: String },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index 3f629d50..79d40cfc 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -60,20 +60,26 @@ impl Layout {
         self.shape.is_fortran_contiguous(&self.stride)
     }
 
-    pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+    pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
         let dims = self.shape().dims();
         if dim >= dims.len() {
-            Err(Error::UnexpectedNumberOfDims {
-                expected: dim + 1,
-                got: dims.len(),
+            Err(Error::DimOutOfRange {
                 shape: self.shape().clone(),
+                dim: dim as i32,
+                op: "narrow",
             })?
         }
-        if start + length > dims[dim] {
-            todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
+        if start + len > dims[dim] {
+            Err(Error::NarrowInvalidArgs {
+                shape: self.shape.clone(),
+                dim,
+                start,
+                len,
+                msg: "start + len > dim_len",
+            })?
         }
         let mut dims = dims.to_vec();
-        dims[dim] = length;
+        dims[dim] = len;
         Ok(Self {
             shape: Shape::from(dims),
             stride: self.stride.clone(),
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 632ef116..b5e64454 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -194,7 +194,7 @@ impl Dim for usize {
         if dim >= shape.dims().len() {
             Err(Error::DimOutOfRange {
                 shape: shape.clone(),
-                dim,
+                dim: dim as i32,
                 op,
             })?
         } else {
@@ -207,7 +207,7 @@ impl Dim for usize {
         if dim > shape.dims().len() {
             Err(Error::DimOutOfRange {
                 shape: shape.clone(),
-                dim,
+                dim: dim as i32,
                 op,
             })?
         } else {
@@ -221,30 +221,36 @@ pub enum D {
     Minus2,
 }
 
+impl D {
+    fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
+        let dim = match self {
+            Self::Minus1 => -1,
+            Self::Minus2 => -2,
+        };
+        Error::DimOutOfRange {
+            shape: shape.clone(),
+            dim,
+            op,
+        }
+    }
+}
+
 impl Dim for D {
     fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
         let rank = shape.rank();
         match self {
             Self::Minus1 if rank >= 1 => Ok(rank - 1),
             Self::Minus2 if rank >= 2 => Ok(rank - 2),
-            _ => Err(Error::DimOutOfRange {
-                shape: shape.clone(),
-                dim: 42, // TODO: Have an adequate error
-                op,
-            }),
+            _ => Err(self.out_of_range(shape, op)),
         }
     }
 
     fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
         let rank = shape.rank();
         match self {
-            Self::Minus1 if rank >= 1 => Ok(rank),
-            Self::Minus2 if rank >= 2 => Ok(rank - 1),
-            _ => Err(Error::DimOutOfRange {
-                shape: shape.clone(),
-                dim: 42, // TODO: Have an adequate error
-                op,
-            }),
+            Self::Minus1 => Ok(rank),
+            Self::Minus2 if rank >= 1 => Ok(rank - 1),
+            _ => Err(self.out_of_range(shape, op)),
         }
     }
 }
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 5d4e106f..f9a6ebb5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -490,7 +490,7 @@ impl Tensor {
         if dim >= self.dims().len() {
             Err(Error::DimOutOfRange {
                 shape: self.shape().clone(),
-                dim,
+                dim: dim as i32,
                 op,
             })?
         } else {
@@ -509,6 +509,7 @@ impl Tensor {
                 dim,
                 start,
                 len,
+                msg: "start + len > dim_len",
             })?
         }
         if start == 0 && dims[dim] == len {
@@ -576,10 +577,22 @@ impl Tensor {
         let (b_size, c_in, l_in) = match *self.dims() {
            [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
            [c_in, l_in] => (None, c_in, l_in),
-            _ => todo!("proper error message"),
+            _ => Err(Error::Conv1dInvalidArgs {
+                inp_shape: self.shape().clone(),
+                k_shape: kernel.shape().clone(),
+                padding,
+                stride,
+                msg: "input rank is not 2 or 3",
+            })?,
         };
         if c_in != c_in_k {
-            todo!("proper error message")
+            Err(Error::Conv1dInvalidArgs {
+                inp_shape: self.shape().clone(),
+                k_shape: kernel.shape().clone(),
+                padding,
+                stride,
+                msg: "the number of in-channels on the input doesn't match the kernel size",
+            })?
         }
         let params = crate::conv::ParamsConv1D {
             b_size,
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 7f68ae08..aa2ec401 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -157,8 +157,9 @@ impl<'a> VarBuilder<'a> {
                 routing,
                 safetensors,
             } => {
-                // Unwrap or 0 just to let the proper error flow.
-                let index = routing.get(&path).unwrap_or(&0);
+                let index = routing.get(&path).ok_or_else(|| Error::CannotFindTensor {
+                    path: path.to_string(),
+                })?;
                 safetensors[*index]
                     .tensor(&path, &data.device)?
                     .to_dtype(data.dtype)?
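On the `// Box indirection to avoid large variant` comment above: a Rust enum is as large as its largest variant, so inlining the two `Layout`s would grow `Error`, and with it every `Result<T>` that carries it. A self-contained sketch of the effect, with the field layout simplified from the real types:

    // Simplified from candle's `Layout`; enough to make the variant bulky.
    #[allow(dead_code)]
    struct Layout {
        shape: Vec<usize>,
        stride: Vec<usize>,
        start_offset: usize,
    }

    #[allow(dead_code)]
    struct MatMulUnexpectedStriding {
        lhs_l: Layout,
        rhs_l: Layout,
        bmnk: (usize, usize, usize, usize),
        msg: &'static str,
    }

    #[allow(dead_code)]
    enum Inlined {
        MatMul(MatMulUnexpectedStriding),
        Small,
    }

    #[allow(dead_code)]
    enum Boxed {
        MatMul(Box<MatMulUnexpectedStriding>),
        Small,
    }

    fn main() {
        // The inlined enum pays for the whole payload in every value;
        // the boxed one pays a single pointer.
        println!("inlined: {} bytes", std::mem::size_of::<Inlined>());
        println!("boxed:   {} bytes", std::mem::size_of::<Boxed>());
    }

The Box costs one extra allocation on the (cold) error path, while every `Result` on the hot path stays small.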