//! Candle-specific Error and Result use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; #[derive(Debug, Clone)] pub struct MatMulUnexpectedStriding { pub lhs_l: Layout, pub rhs_l: Layout, pub bmnk: (usize, usize, usize, usize), pub msg: &'static str, } impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{self}") } } /// Main library error type. #[derive(thiserror::Error)] pub enum Error { // === DType Errors === #[error("{msg}, expected: {expected:?}, got: {got:?}")] UnexpectedDType { msg: &'static str, expected: DType, got: DType, }, #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] DTypeMismatchBinaryOp { lhs: DType, rhs: DType, op: &'static str, }, #[error("unsupported dtype {0:?} for op {1}")] UnsupportedDTypeForOp(DType, &'static str), // === Dimension Index Errors === #[error("{op}: dimension index {dim} out of range for shape {shape:?}")] DimOutOfRange { shape: Shape, dim: i32, op: &'static str, }, #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")] DuplicateDimIndex { shape: Shape, dims: Vec, op: &'static str, }, // === Shape Errors === #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")] UnexpectedNumberOfDims { expected: usize, got: usize, shape: Shape, }, #[error("{msg}, expected: {expected:?}, got: {got:?}")] UnexpectedShape { msg: String, expected: Shape, got: Shape, }, #[error( "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}" )] ShapeMismatch { buffer_size: usize, shape: Shape }, #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] ShapeMismatchBinaryOp { lhs: Shape, rhs: Shape, op: &'static str, }, #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")] ShapeMismatchCat { dim: usize, first_shape: Shape, n: usize, nth_shape: Shape, }, #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")] ShapeMismatchSplit { shape: Shape, dim: usize, n_parts: usize, }, #[error("{op} can only be performed on a single dimension")] OnlySingleDimension { op: &'static str, dims: Vec }, #[error("empty tensor for {op}")] EmptyTensor { op: &'static str }, // === Device Errors === #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] DeviceMismatchBinaryOp { lhs: DeviceLocation, rhs: DeviceLocation, op: &'static str, }, // === Op Specific Errors === #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")] NarrowInvalidArgs { shape: Shape, dim: usize, start: usize, len: usize, msg: &'static str, }, #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")] Conv1dInvalidArgs { inp_shape: Shape, k_shape: Shape, padding: usize, stride: usize, msg: &'static str, }, #[error("{op} invalid index {index} with dim size {size}")] InvalidIndex { op: &'static str, index: usize, size: usize, }, #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")] BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape }, #[error("cannot set variable {msg}")] CannotSetVar { msg: &'static str }, // Box indirection to avoid large variant. #[error("{0:?}")] MatMulUnexpectedStriding(Box), #[error("{op} only supports contiguous tensors")] RequiresContiguous { op: &'static str }, #[error("{op} expects at least one tensor")] OpRequiresAtLeastOneTensor { op: &'static str }, #[error("{op} expects at least two tensors")] OpRequiresAtLeastTwoTensors { op: &'static str }, #[error("backward is not supported for {op}")] BackwardNotSupported { op: &'static str }, // === Other Errors === #[error("the candle crate has not been built with cuda support")] NotCompiledWithCudaSupport, #[error("the candle crate has not been built with metal support")] NotCompiledWithMetalSupport, #[error("cannot find tensor {path}")] CannotFindTensor { path: String }, // === Wrapped Errors === #[error(transparent)] Cuda(Box), #[error("Metal error {0}")] Metal(#[from] MetalError), #[error(transparent)] Ug(#[from] ug::Error), #[error(transparent)] TryFromIntError(#[from] core::num::TryFromIntError), #[error("npy/npz error {0}")] Npy(String), /// Zip file format error. #[error(transparent)] Zip(#[from] zip::result::ZipError), /// Integer parse error. #[error(transparent)] ParseInt(#[from] std::num::ParseIntError), /// Utf8 parse error. #[error(transparent)] FromUtf8(#[from] std::string::FromUtf8Error), /// I/O error. #[error(transparent)] Io(#[from] std::io::Error), /// SafeTensor error. #[error(transparent)] SafeTensor(#[from] safetensors::SafeTensorError), #[error("unsupported safetensor dtype {0:?}")] UnsupportedSafeTensorDtype(safetensors::Dtype), /// Arbitrary errors wrapping. #[error("{0}")] Wrapped(Box), #[error("{context}\n{inner}")] Context { inner: Box, context: Box, }, /// Adding path information to an error. #[error("path: {path:?} {inner}")] WithPath { inner: Box, path: std::path::PathBuf, }, #[error("{inner}\n{backtrace}")] WithBacktrace { inner: Box, backtrace: Box, }, /// User generated error message, typically created via `bail!`. #[error("{0}")] Msg(String), #[error("unwrap none")] UnwrapNone, } pub type Result = std::result::Result; impl Error { pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self { Self::Wrapped(Box::new(err)).bt() } pub fn msg(err: impl std::fmt::Display) -> Self { Self::Msg(err.to_string()).bt() } pub fn debug(err: impl std::fmt::Debug) -> Self { Self::Msg(format!("{err:?}")).bt() } pub fn bt(self) -> Self { let backtrace = std::backtrace::Backtrace::capture(); match backtrace.status() { std::backtrace::BacktraceStatus::Disabled | std::backtrace::BacktraceStatus::Unsupported => self, _ => Self::WithBacktrace { inner: Box::new(self), backtrace: Box::new(backtrace), }, } } pub fn with_path>(self, p: P) -> Self { Self::WithPath { inner: Box::new(self), path: p.as_ref().to_path_buf(), } } pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self { Self::Context { inner: Box::new(self), context: Box::new(c), } } } #[macro_export] macro_rules! bail { ($msg:literal $(,)?) => { return Err($crate::Error::Msg(format!($msg).into()).bt()) }; ($err:expr $(,)?) => { return Err($crate::Error::Msg(format!($err).into()).bt()) }; ($fmt:expr, $($arg:tt)*) => { return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt()) }; } pub fn zip(r1: Result, r2: Result) -> Result<(T, U)> { match (r1, r2) { (Ok(r1), Ok(r2)) => Ok((r1, r2)), (Err(e), _) => Err(e), (_, Err(e)) => Err(e), } } // Taken from anyhow. pub trait Context { /// Wrap the error value with additional context. fn context(self, context: C) -> Result where C: std::fmt::Display + Send + Sync + 'static; /// Wrap the error value with additional context that is evaluated lazily /// only once an error does occur. fn with_context(self, f: F) -> Result where C: std::fmt::Display + Send + Sync + 'static, F: FnOnce() -> C; } impl Context for Option { fn context(self, context: C) -> Result where C: std::fmt::Display + Send + Sync + 'static, { match self { Some(v) => Ok(v), None => Err(Error::UnwrapNone.context(context).bt()), } } fn with_context(self, f: F) -> Result where C: std::fmt::Display + Send + Sync + 'static, F: FnOnce() -> C, { match self { Some(v) => Ok(v), None => Err(Error::UnwrapNone.context(f()).bt()), } } }