diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/error.rs | 70 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-core/src/pickle.rs | 8 | ||||
-rw-r--r-- | candle-core/src/quantized/gguf_file.rs | 4 | ||||
-rw-r--r-- | candle-core/src/quantized/mod.rs | 4 | ||||
-rw-r--r-- | candle-core/src/tensor_cat.rs | 4 |
6 files changed, 76 insertions, 16 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 15604c15..85a9d230 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding { pub msg: &'static str, } +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + /// Main library error type. -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error)] pub enum Error { // === DType Errors === #[error("{msg}, expected: {expected:?}, got: {got:?}")] @@ -199,8 +205,14 @@ pub enum Error { UnsupportedSafeTensorDtype(safetensors::Dtype), /// Arbitrary errors wrapping. - #[error(transparent)] - Wrapped(Box<dyn std::error::Error + Send + Sync>), + #[error("{0}")] + Wrapped(Box<dyn std::fmt::Display + Send + Sync>), + + #[error("{context}\n{inner}")] + Context { + inner: Box<Self>, + context: Box<dyn std::fmt::Display + Send + Sync>, + }, /// Adding path information to an error. #[error("path: {path:?} {inner}")] @@ -218,16 +230,19 @@ pub enum Error { /// User generated error message, typically created via `bail!`. #[error("{0}")] Msg(String), + + #[error("unwrap none")] + UnwrapNone, } pub type Result<T> = std::result::Result<T, Error>; impl Error { - pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { + pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self { Self::Wrapped(Box::new(err)).bt() } - pub fn msg(err: impl std::error::Error) -> Self { + pub fn msg(err: impl std::fmt::Display) -> Self { Self::Msg(err.to_string()).bt() } @@ -253,6 +268,13 @@ impl Error { path: p.as_ref().to_path_buf(), } } + + pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self { + Self::Context { + inner: Box::new(self), + context: Box::new(c), + } + } } #[macro_export] @@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> { (_, Err(e)) => Err(e), } } + +// Taken from anyhow. +pub trait Context<T> { + /// Wrap the error value with additional context. + fn context<C>(self, context: C) -> Result<T> + where + C: std::fmt::Display + Send + Sync + 'static; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context<C, F>(self, f: F) -> Result<T> + where + C: std::fmt::Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl<T> Context<T> for Option<T> { + fn context<C>(self, context: C) -> Result<T> + where + C: std::fmt::Display + Send + Sync + 'static, + { + match self { + Some(v) => Ok(v), + None => Err(Error::UnwrapNone.context(context).bt()), + } + } + + fn with_context<C, F>(self, f: F) -> Result<T> + where + C: std::fmt::Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Some(v) => Ok(v), + None => Err(Error::UnwrapNone.context(f()).bt()), + } + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 5f9a1c97..16dc8e02 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef}; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; -pub use error::{Error, Result}; +pub use error::{Context, Error, Result}; pub use indexer::{IndexOp, TensorIndexer}; pub use layout::Layout; pub use shape::{Shape, D}; diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 24f13d20..1632cc26 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -1,7 +1,7 @@ //! Just enough pickle support to be able to read PyTorch checkpoints. // This hardcodes objects that are required for tensor reading, we may want to make this a bit more // composable/tensor agnostic at some point. -use crate::{DType, Error as E, Layout, Result, Tensor}; +use crate::{Context, DType, Error as E, Layout, Result, Tensor}; use byteorder::{LittleEndian, ReadBytesExt}; use std::collections::HashMap; use std::io::BufRead; @@ -537,7 +537,7 @@ impl Stack { crate::bail!("setitems: not an even number of objects") } while let Some(value) = objs.pop() { - let key = objs.pop().unwrap(); + let key = objs.pop().context("empty objs")?; d.push((key, value)) } } else { @@ -557,7 +557,7 @@ impl Stack { crate::bail!("setitems: not an even number of objects") } while let Some(value) = objs.pop() { - let key = objs.pop().unwrap(); + let key = objs.pop().context("empty objs")?; pydict.push((key, value)) } self.push(Object::Dict(pydict)) @@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>( if !file_name.ends_with("data.pkl") { continue; } - let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap()); + let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?); let reader = zip.by_name(file_name)?; let mut reader = std::io::BufReader::new(reader); let mut stack = Stack::empty(); diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index ccbd59eb..2ea6c7a3 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -2,7 +2,7 @@ //! use super::{GgmlDType, QTensor}; -use crate::{Device, Result}; +use crate::{Context, Device, Result}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::collections::HashMap; @@ -338,7 +338,7 @@ impl Value { if value_type.len() != 1 { crate::bail!("multiple value-types in the same array {value_type:?}") } - value_type.into_iter().next().unwrap() + value_type.into_iter().next().context("empty value_type")? }; w.write_u32::<LittleEndian>(value_type.to_u32())?; w.write_u64::<LittleEndian>(v.len() as u64)?; diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 236f5a98..802c5691 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,5 +1,5 @@ //! Code for GGML and GGUF files -use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor}; +use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor}; use k_quants::*; use std::borrow::Cow; @@ -481,7 +481,7 @@ impl crate::CustomOp1 for QTensor { crate::bail!("input tensor has only one dimension {layout:?}") } let mut dst_shape = src_shape.dims().to_vec(); - let last_k = dst_shape.pop().unwrap(); + let last_k = dst_shape.pop().context("empty dst_shape")?; if last_k != k { crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape) } diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs index 204e7fd6..be6dfe61 100644 --- a/candle-core/src/tensor_cat.rs +++ b/candle-core/src/tensor_cat.rs @@ -1,4 +1,4 @@ -use crate::{shape::Dim, Error, Result, Shape, Tensor}; +use crate::{shape::Dim, Context, Error, Result, Shape, Tensor}; impl Tensor { /// Concatenates two or more tensors along a particular dimension. @@ -134,7 +134,7 @@ impl Tensor { .bt())? } } - let next_offset = offsets.last().unwrap() + arg.elem_count(); + let next_offset = offsets.last().context("empty offsets")? + arg.elem_count(); offsets.push(next_offset); } let shape = Shape::from(cat_dims); |