summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/error.rs70
-rw-r--r--candle-core/src/lib.rs2
-rw-r--r--candle-core/src/pickle.rs8
-rw-r--r--candle-core/src/quantized/gguf_file.rs4
-rw-r--r--candle-core/src/quantized/mod.rs4
-rw-r--r--candle-core/src/tensor_cat.rs4
6 files changed, 76 insertions, 16 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 15604c15..85a9d230 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
pub msg: &'static str,
}
+impl std::fmt::Debug for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{self}")
+ }
+}
+
/// Main library error type.
-#[derive(thiserror::Error, Debug)]
+#[derive(thiserror::Error)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
@@ -199,8 +205,14 @@ pub enum Error {
UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
- #[error(transparent)]
- Wrapped(Box<dyn std::error::Error + Send + Sync>),
+ #[error("{0}")]
+ Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
+
+ #[error("{context}\n{inner}")]
+ Context {
+ inner: Box<Self>,
+ context: Box<dyn std::fmt::Display + Send + Sync>,
+ },
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
@@ -218,16 +230,19 @@ pub enum Error {
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
+
+ #[error("unwrap none")]
+ UnwrapNone,
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
- pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+ pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
- pub fn msg(err: impl std::error::Error) -> Self {
+ pub fn msg(err: impl std::fmt::Display) -> Self {
Self::Msg(err.to_string()).bt()
}
@@ -253,6 +268,13 @@ impl Error {
path: p.as_ref().to_path_buf(),
}
}
+
+ pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
+ Self::Context {
+ inner: Box::new(self),
+ context: Box::new(c),
+ }
+ }
}
#[macro_export]
@@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
(_, Err(e)) => Err(e),
}
}
+
+// Taken from anyhow.
+pub trait Context<T> {
+ /// Wrap the error value with additional context.
+ fn context<C>(self, context: C) -> Result<T>
+ where
+ C: std::fmt::Display + Send + Sync + 'static;
+
+ /// Wrap the error value with additional context that is evaluated lazily
+ /// only once an error does occur.
+ fn with_context<C, F>(self, f: F) -> Result<T>
+ where
+ C: std::fmt::Display + Send + Sync + 'static,
+ F: FnOnce() -> C;
+}
+
+impl<T> Context<T> for Option<T> {
+ fn context<C>(self, context: C) -> Result<T>
+ where
+ C: std::fmt::Display + Send + Sync + 'static,
+ {
+ match self {
+ Some(v) => Ok(v),
+ None => Err(Error::UnwrapNone.context(context).bt()),
+ }
+ }
+
+ fn with_context<C, F>(self, f: F) -> Result<T>
+ where
+ C: std::fmt::Display + Send + Sync + 'static,
+ F: FnOnce() -> C,
+ {
+ match self {
+ Some(v) => Ok(v),
+ None => Err(Error::UnwrapNone.context(f()).bt()),
+ }
+ }
+}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 5f9a1c97..16dc8e02 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
-pub use error::{Error, Result};
+pub use error::{Context, Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use layout::Layout;
pub use shape::{Shape, D};
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 24f13d20..1632cc26 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -1,7 +1,7 @@
//! Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
-use crate::{DType, Error as E, Layout, Result, Tensor};
+use crate::{Context, DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
@@ -537,7 +537,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
- let key = objs.pop().unwrap();
+ let key = objs.pop().context("empty objs")?;
d.push((key, value))
}
} else {
@@ -557,7 +557,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
- let key = objs.pop().unwrap();
+ let key = objs.pop().context("empty objs")?;
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
@@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
if !file_name.ends_with("data.pkl") {
continue;
}
- let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
+ let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index ccbd59eb..2ea6c7a3 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -2,7 +2,7 @@
//!
use super::{GgmlDType, QTensor};
-use crate::{Device, Result};
+use crate::{Context, Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@@ -338,7 +338,7 @@ impl Value {
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
- value_type.into_iter().next().unwrap()
+ value_type.into_iter().next().context("empty value_type")?
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 236f5a98..802c5691 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,5 +1,5 @@
//! Code for GGML and GGUF files
-use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
+use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
@@ -481,7 +481,7 @@ impl crate::CustomOp1 for QTensor {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
- let last_k = dst_shape.pop().unwrap();
+ let last_k = dst_shape.pop().context("empty dst_shape")?;
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
index 204e7fd6..be6dfe61 100644
--- a/candle-core/src/tensor_cat.rs
+++ b/candle-core/src/tensor_cat.rs
@@ -1,4 +1,4 @@
-use crate::{shape::Dim, Error, Result, Shape, Tensor};
+use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};
impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
@@ -134,7 +134,7 @@ impl Tensor {
.bt())?
}
}
- let next_offset = offsets.last().unwrap() + arg.elem_count();
+ let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);