summaryrefslogtreecommitdiff
path: root/candle-core/src/error.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/error.rs')
-rw-r--r--candle-core/src/error.rs70
1 files changed, 65 insertions, 5 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 15604c15..85a9d230 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
pub msg: &'static str,
}
+impl std::fmt::Debug for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{self}")
+ }
+}
+
/// Main library error type.
-#[derive(thiserror::Error, Debug)]
+#[derive(thiserror::Error)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
@@ -199,8 +205,14 @@ pub enum Error {
UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
- #[error(transparent)]
- Wrapped(Box<dyn std::error::Error + Send + Sync>),
+ #[error("{0}")]
+ Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
+
+ #[error("{context}\n{inner}")]
+ Context {
+ inner: Box<Self>,
+ context: Box<dyn std::fmt::Display + Send + Sync>,
+ },
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
@@ -218,16 +230,19 @@ pub enum Error {
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
+
+ #[error("unwrap none")]
+ UnwrapNone,
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
- pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+ pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
- pub fn msg(err: impl std::error::Error) -> Self {
+ pub fn msg(err: impl std::fmt::Display) -> Self {
Self::Msg(err.to_string()).bt()
}
@@ -253,6 +268,13 @@ impl Error {
path: p.as_ref().to_path_buf(),
}
}
+
+ pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
+ Self::Context {
+ inner: Box::new(self),
+ context: Box::new(c),
+ }
+ }
}
#[macro_export]
@@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
(_, Err(e)) => Err(e),
}
}
+
+// Taken from anyhow.
+pub trait Context<T> {
+ /// Wrap the error value with additional context.
+ fn context<C>(self, context: C) -> Result<T>
+ where
+ C: std::fmt::Display + Send + Sync + 'static;
+
+ /// Wrap the error value with additional context that is evaluated lazily
+ /// only once an error does occur.
+ fn with_context<C, F>(self, f: F) -> Result<T>
+ where
+ C: std::fmt::Display + Send + Sync + 'static,
+ F: FnOnce() -> C;
+}
+
+impl<T> Context<T> for Option<T> {
+ fn context<C>(self, context: C) -> Result<T>
+ where
+ C: std::fmt::Display + Send + Sync + 'static,
+ {
+ match self {
+ Some(v) => Ok(v),
+ None => Err(Error::UnwrapNone.context(context).bt()),
+ }
+ }
+
+ fn with_context<C, F>(self, f: F) -> Result<T>
+ where
+ C: std::fmt::Display + Send + Sync + 'static,
+ F: FnOnce() -> C,
+ {
+ match self {
+ Some(v) => Ok(v),
+ None => Err(Error::UnwrapNone.context(f()).bt()),
+ }
+ }
+}