diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-12 09:17:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-12 09:17:08 +0100 |
commit | a76ec797da866d92a39ca41a2c9e70de6d6d7df7 (patch) | |
tree | 56021e6e33aebaca580a72331b32cd5ae1145fe1 /candle-core/src/error.rs | |
parent | fa760759e5fa94c8486566af6dd3a456d0548221 (diff) | |
download | candle-a76ec797da866d92a39ca41a2c9e70de6d6d7df7.tar.gz candle-a76ec797da866d92a39ca41a2c9e70de6d6d7df7.tar.bz2 candle-a76ec797da866d92a39ca41a2c9e70de6d6d7df7.zip |
Cleanup the main crate error and add a couple dedicated ones (#142)
* Cosmetic cleanups to the error enum.
* More error cleanup.
* Proper error handling rather than panicing.
* Add some conv1d dedicated error.
Diffstat (limited to 'candle-core/src/error.rs')
-rw-r--r-- | candle-core/src/error.rs | 135 |
1 files changed, 75 insertions, 60 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 27fd11bb..b9131356 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -1,8 +1,17 @@ -use crate::{DType, DeviceLocation, Shape}; +use crate::{DType, DeviceLocation, Layout, Shape}; + +#[derive(Debug, Clone)] +pub struct MatMulUnexpectedStriding { + pub lhs_l: Layout, + pub rhs_l: Layout, + pub bmnk: (usize, usize, usize, usize), + pub msg: &'static str, +} /// Main library error type. #[derive(thiserror::Error, Debug)] pub enum Error { + // === DType Errors === #[error("{msg}, expected: {expected:?}, got: {got:?}")] UnexpectedDType { msg: &'static str, @@ -10,47 +19,39 @@ pub enum Error { got: DType, }, - #[error("{msg}, expected: {expected:?}, got: {got:?}")] - UnexpectedShape { - msg: String, - expected: Shape, - got: Shape, + #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] + DTypeMismatchBinaryOp { + lhs: DType, + rhs: DType, + op: &'static str, }, + #[error("unsupported dtype {0:?} for op {1}")] + UnsupportedDTypeForOp(DType, &'static str), + + // === Dimension Index Errors === #[error("{op}: dimension index {dim} out of range for {shape:?}")] DimOutOfRange { shape: Shape, - dim: usize, + dim: i32, op: &'static str, }, - #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")] - NarrowInvalidArgs { + // === Shape Errors === + #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")] + UnexpectedNumberOfDims { + expected: usize, + got: usize, shape: Shape, - dim: usize, - start: usize, - len: usize, }, - #[error("{op} only supports contiguous tensors")] - RequiresContiguous { op: &'static str }, - - #[error("{op} expects at least one tensor")] - OpRequiresAtLeastOneTensor { op: &'static str }, - - #[error("backward is not supported for {op}")] - BackwardNotSupported { op: &'static str }, - - #[error("{op} invalid index {index} with vocab {vocab_size}")] - InvalidIndex { - op: &'static str, - index: usize, - vocab_size: usize, + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedShape { + msg: String, + expected: Shape, + got: Shape, }, - #[error("the candle crate has not been built with cuda support")] - NotCompiledWithCudaSupport, - #[error( "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}" )] @@ -71,6 +72,7 @@ pub enum Error { nth_shape: Shape, }, + // === Device Errors === #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] DeviceMismatchBinaryOp { lhs: DeviceLocation, @@ -78,27 +80,56 @@ pub enum Error { op: &'static str, }, - #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] - DTypeMismatchBinaryOp { - lhs: DType, - rhs: DType, - op: &'static str, + // === Op Specific Errors === + #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")] + NarrowInvalidArgs { + shape: Shape, + dim: usize, + start: usize, + len: usize, + msg: &'static str, }, - #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")] - UnexpectedNumberOfDims { - expected: usize, - got: usize, - shape: Shape, + #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")] + Conv1dInvalidArgs { + inp_shape: Shape, + k_shape: Shape, + padding: usize, + stride: usize, + msg: &'static str, }, - // TODO this is temporary when we support arbitrary matmul - #[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")] - UnexpectedStriding { - lhs_stride: Vec<usize>, - rhs_stride: Vec<usize>, + #[error("{op} invalid index {index} with vocab {vocab_size}")] + InvalidIndex { + op: &'static str, + index: usize, + vocab_size: usize, }, + #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")] + BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape }, + + // Box indirection to avoid large variant. + #[error("{0:?}")] + MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>), + + #[error("{op} only supports contiguous tensors")] + RequiresContiguous { op: &'static str }, + + #[error("{op} expects at least one tensor")] + OpRequiresAtLeastOneTensor { op: &'static str }, + + #[error("backward is not supported for {op}")] + BackwardNotSupported { op: &'static str }, + + // === Other Errors === + #[error("the candle crate has not been built with cuda support")] + NotCompiledWithCudaSupport, + + #[error("cannot find tensor {path}")] + CannotFindTensor { path: String }, + + // === Wrapped Errors === #[error(transparent)] Cuda(Box<dyn std::error::Error + Send + Sync>), @@ -126,22 +157,6 @@ pub enum Error { #[error("unsupported safetensor dtype {0:?}")] UnsupportedSafeTensorDtype(safetensors::Dtype), - - #[error("unsupported dtype {0:?} for op {1}")] - UnsupportedDTypeForOp(DType, &'static str), - - #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")] - BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape }, - - #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] - MatMulNonContiguous { - lhs_stride: Vec<usize>, - rhs_stride: Vec<usize>, - mnk: (usize, usize, usize), - }, - - #[error("cannot find tensor {path}")] - CannotFindTensor { path: String }, } pub type Result<T> = std::result::Result<T, Error>; |