diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-12-22 09:18:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-22 09:18:13 +0100 |
commit | 62ced44ea94da7062430ed6c21ff17b36f41737d (patch) | |
tree | ffcb633955da0d743b013266de9b8b45bd59a1f0 | |
parent | 5c2f893e5aa21c9f7c82a00407edb6d76db1d06c (diff) | |
download | candle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.gz candle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.bz2 candle-62ced44ea94da7062430ed6c21ff17b36f41737d.zip |
Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.
* Switch two unwrap to context.
-rw-r--r-- | candle-core/src/error.rs | 70 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-core/src/pickle.rs | 8 | ||||
-rw-r--r-- | candle-core/src/quantized/gguf_file.rs | 4 | ||||
-rw-r--r-- | candle-core/src/quantized/mod.rs | 4 | ||||
-rw-r--r-- | candle-core/src/tensor_cat.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/generation/mod.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/chinese_clip/vision_model.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/clip/vision_model.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/efficientnet.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/fastvit.rs | 4 | ||||
-rw-r--r-- | candle-transformers/src/models/llava/mod.rs | 22 | ||||
-rw-r--r-- | candle-transformers/src/models/segformer.rs | 4 |
13 files changed, 97 insertions, 41 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 15604c15..85a9d230 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding { pub msg: &'static str, } +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + /// Main library error type. -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error)] pub enum Error { // === DType Errors === #[error("{msg}, expected: {expected:?}, got: {got:?}")] @@ -199,8 +205,14 @@ pub enum Error { UnsupportedSafeTensorDtype(safetensors::Dtype), /// Arbitrary errors wrapping. - #[error(transparent)] - Wrapped(Box<dyn std::error::Error + Send + Sync>), + #[error("{0}")] + Wrapped(Box<dyn std::fmt::Display + Send + Sync>), + + #[error("{context}\n{inner}")] + Context { + inner: Box<Self>, + context: Box<dyn std::fmt::Display + Send + Sync>, + }, /// Adding path information to an error. #[error("path: {path:?} {inner}")] @@ -218,16 +230,19 @@ pub enum Error { /// User generated error message, typically created via `bail!`. #[error("{0}")] Msg(String), + + #[error("unwrap none")] + UnwrapNone, } pub type Result<T> = std::result::Result<T, Error>; impl Error { - pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { + pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self { Self::Wrapped(Box::new(err)).bt() } - pub fn msg(err: impl std::error::Error) -> Self { + pub fn msg(err: impl std::fmt::Display) -> Self { Self::Msg(err.to_string()).bt() } @@ -253,6 +268,13 @@ impl Error { path: p.as_ref().to_path_buf(), } } + + pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self { + Self::Context { + inner: Box::new(self), + context: Box::new(c), + } + } } #[macro_export] @@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> { (_, Err(e)) => Err(e), } } + +// Taken from anyhow. +pub trait Context<T> { + /// Wrap the error value with additional context. + fn context<C>(self, context: C) -> Result<T> + where + C: std::fmt::Display + Send + Sync + 'static; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context<C, F>(self, f: F) -> Result<T> + where + C: std::fmt::Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl<T> Context<T> for Option<T> { + fn context<C>(self, context: C) -> Result<T> + where + C: std::fmt::Display + Send + Sync + 'static, + { + match self { + Some(v) => Ok(v), + None => Err(Error::UnwrapNone.context(context).bt()), + } + } + + fn with_context<C, F>(self, f: F) -> Result<T> + where + C: std::fmt::Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Some(v) => Ok(v), + None => Err(Error::UnwrapNone.context(f()).bt()), + } + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 5f9a1c97..16dc8e02 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef}; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; -pub use error::{Error, Result}; +pub use error::{Context, Error, Result}; pub use indexer::{IndexOp, TensorIndexer}; pub use layout::Layout; pub use shape::{Shape, D}; diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 24f13d20..1632cc26 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -1,7 +1,7 @@ //! Just enough pickle support to be able to read PyTorch checkpoints. // This hardcodes objects that are required for tensor reading, we may want to make this a bit more // composable/tensor agnostic at some point. -use crate::{DType, Error as E, Layout, Result, Tensor}; +use crate::{Context, DType, Error as E, Layout, Result, Tensor}; use byteorder::{LittleEndian, ReadBytesExt}; use std::collections::HashMap; use std::io::BufRead; @@ -537,7 +537,7 @@ impl Stack { crate::bail!("setitems: not an even number of objects") } while let Some(value) = objs.pop() { - let key = objs.pop().unwrap(); + let key = objs.pop().context("empty objs")?; d.push((key, value)) } } else { @@ -557,7 +557,7 @@ impl Stack { crate::bail!("setitems: not an even number of objects") } while let Some(value) = objs.pop() { - let key = objs.pop().unwrap(); + let key = objs.pop().context("empty objs")?; pydict.push((key, value)) } self.push(Object::Dict(pydict)) @@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>( if !file_name.ends_with("data.pkl") { continue; } - let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap()); + let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?); let reader = zip.by_name(file_name)?; let mut reader = std::io::BufReader::new(reader); let mut stack = Stack::empty(); diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index ccbd59eb..2ea6c7a3 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -2,7 +2,7 @@ //! use super::{GgmlDType, QTensor}; -use crate::{Device, Result}; +use crate::{Context, Device, Result}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::collections::HashMap; @@ -338,7 +338,7 @@ impl Value { if value_type.len() != 1 { crate::bail!("multiple value-types in the same array {value_type:?}") } - value_type.into_iter().next().unwrap() + value_type.into_iter().next().context("empty value_type")? }; w.write_u32::<LittleEndian>(value_type.to_u32())?; w.write_u64::<LittleEndian>(v.len() as u64)?; diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 236f5a98..802c5691 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,5 +1,5 @@ //! Code for GGML and GGUF files -use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor}; +use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor}; use k_quants::*; use std::borrow::Cow; @@ -481,7 +481,7 @@ impl crate::CustomOp1 for QTensor { crate::bail!("input tensor has only one dimension {layout:?}") } let mut dst_shape = src_shape.dims().to_vec(); - let last_k = dst_shape.pop().unwrap(); + let last_k = dst_shape.pop().context("empty dst_shape")?; if last_k != k { crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape) } diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs index 204e7fd6..be6dfe61 100644 --- a/candle-core/src/tensor_cat.rs +++ b/candle-core/src/tensor_cat.rs @@ -1,4 +1,4 @@ -use crate::{shape::Dim, Error, Result, Shape, Tensor}; +use crate::{shape::Dim, Context, Error, Result, Shape, Tensor}; impl Tensor { /// Concatenates two or more tensors along a particular dimension. @@ -134,7 +134,7 @@ impl Tensor { .bt())? } } - let next_offset = offsets.last().unwrap() + arg.elem_count(); + let next_offset = offsets.last().context("empty offsets")? + arg.elem_count(); offsets.push(next_offset); } let shape = Shape::from(cat_dims); diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index d95a0595..85ffb59c 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -3,7 +3,7 @@ //! Functionality for modeling sampling strategies and logits processing in text generation //! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p), //! and combinations thereof. -use candle::{DType, Error, Result, Tensor}; +use candle::{Context, DType, Error, Result, Tensor}; use rand::{distributions::Distribution, SeedableRng}; #[derive(Clone, PartialEq, Debug)] @@ -45,7 +45,7 @@ impl LogitsProcessor { .enumerate() .max_by(|(_, u), (_, v)| u.total_cmp(v)) .map(|(i, _)| i as u32) - .unwrap(); + .context("empty logits")?; Ok(next_token) } diff --git a/candle-transformers/src/models/chinese_clip/vision_model.rs b/candle-transformers/src/models/chinese_clip/vision_model.rs index a20535c4..153fe833 100644 --- a/candle-transformers/src/models/chinese_clip/vision_model.rs +++ b/candle-transformers/src/models/chinese_clip/vision_model.rs @@ -6,7 +6,7 @@ //! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP) //! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py_ -use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D}; +use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D}; use candle_nn as nn; use super::{Activation, EncoderConfig}; @@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer { .apply(&self.pre_layer_norm)?; let mut result = self.encoder.output_hidden_states(&hidden_states, None)?; - let encoder_outputs = result.last().unwrap(); + let encoder_outputs = result.last().context("no last")?; let pooled_output = encoder_outputs.i((.., 0, ..))?; result.push(self.final_layer_norm.forward(&pooled_output)?.clone()); Ok(result) diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs index e64cab16..90314420 100644 --- a/candle-transformers/src/models/clip/vision_model.rs +++ b/candle-transformers/src/models/clip/vision_model.rs @@ -6,7 +6,7 @@ //! https://github.com/openai/CLIP //! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip -use candle::{IndexOp, Result, Shape, Tensor, D}; +use candle::{Context, IndexOp, Result, Shape, Tensor, D}; use candle_nn as nn; use candle_nn::Module; use nn::Conv2dConfig; @@ -149,7 +149,7 @@ impl ClipVisionTransformer { .apply(&self.embeddings)? .apply(&self.pre_layer_norm)?; let mut result = self.encoder.output_hidden_states(&hidden_states, None)?; - let encoder_outputs = result.last().unwrap(); + let encoder_outputs = result.last().context("no last")?; let pooled_output = encoder_outputs.i((.., 0, ..))?; result.push(self.final_layer_norm.forward(&pooled_output)?.clone()); Ok(result) diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs index 36754f21..be695460 100644 --- a/candle-transformers/src/models/efficientnet.rs +++ b/candle-transformers/src/models/efficientnet.rs @@ -3,7 +3,7 @@ //! See: //! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462) //! -use candle::{Result, Tensor, D}; +use candle::{Context, Result, Tensor, D}; use candle_nn as nn; use nn::{Module, VarBuilder}; @@ -289,7 +289,7 @@ impl EfficientNet { pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> { let f_p = p.pp("features"); let first_in_c = configs[0].input_channels; - let last_out_c = configs.last().unwrap().out_channels; + let last_out_c = configs.last().context("no last")?.out_channels; let final_out_c = 4 * last_out_c; let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?; let nconfigs = configs.len(); diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs index 4e296653..3f8664d9 100644 --- a/candle-transformers/src/models/fastvit.rs +++ b/candle-transformers/src/models/fastvit.rs @@ -5,7 +5,7 @@ //! //! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py) -use candle::{DType, Result, Tensor, D}; +use candle::{Context, DType, Result, Tensor, D}; use candle_nn::{ batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder, @@ -178,7 +178,7 @@ fn squeeze_and_excitation( // based on the _fuse_bn_tensor method in timm // see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602 fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> { - let (gamma, beta) = bn.weight_and_bias().unwrap(); + let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?; let mu = bn.running_mean(); let sigma = (bn.running_var() + bn.eps())?.sqrt(); let gps = (gamma / sigma)?; diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs index c252dbed..bc855538 100644 --- a/candle-transformers/src/models/llava/mod.rs +++ b/candle-transformers/src/models/llava/mod.rs @@ -14,7 +14,7 @@ use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer} use crate::models::llama::{Cache, Llama}; use crate::models::with_tracing::linear; -use candle::{bail, Device, IndexOp, Result, Tensor}; +use candle::{bail, Context, Device, IndexOp, Result, Tensor}; use candle_nn::{seq, Activation, Module, Sequential, VarBuilder}; use fancy_regex::Regex; use utils::get_anyres_image_grid_shape; @@ -145,7 +145,7 @@ impl ClipVisionTower { let config = if config.is_none() { ClipVisionConfig::clip_vit_large_patch14_336() } else { - config.clone().unwrap() + config.clone().context("no config")? }; let select_layer = match select_layer { -1 | -2 => select_layer, @@ -262,14 +262,14 @@ impl LLaVA { let image_features = if mm_patch_merge_type == "flat" { image_features .iter() - .map(|x| x.flatten(0, 1).unwrap()) - .collect::<Vec<Tensor>>() + .map(|x| x.flatten(0, 1)) + .collect::<Result<Vec<Tensor>>>()? } else if mm_patch_merge_type.starts_with("spatial") { let mut new_image_features = Vec::new(); for (image_idx, image_feature) in image_features.iter().enumerate() { let new_image_feature = if image_feature.dims()[0] > 1 { - let base_image_feature = image_feature.get(0).unwrap(); - let patch_image_feature = image_feature.i(1..).unwrap(); + let base_image_feature = image_feature.get(0)?; + let patch_image_feature = image_feature.i(1..)?; let height = self.clip_vision_tower.num_patches_per_side(); let width = height; assert_eq!(height * width, base_image_feature.dims()[0]); @@ -313,16 +313,12 @@ impl LLaVA { }; Tensor::cat(&[base_image_feature, new_image_feature], 0)? } else { - let new_image_feature = image_feature.get(0).unwrap(); + let new_image_feature = image_feature.get(0)?; if mm_patch_merge_type.contains("unpad") { Tensor::cat( - &[ - new_image_feature, - self.image_newline.clone().unsqueeze(0).unwrap(), - ], + &[new_image_feature, self.image_newline.clone().unsqueeze(0)?], 0, - ) - .unwrap() + )? } else { new_image_feature } diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs index 9e0461bc..6d750df2 100644 --- a/candle-transformers/src/models/segformer.rs +++ b/candle-transformers/src/models/segformer.rs @@ -15,7 +15,7 @@ //! use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear}; -use candle::{Module, ModuleT, Result, Tensor, D}; +use candle::{Context, Module, ModuleT, Result, Tensor, D}; use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; @@ -633,7 +633,7 @@ impl ImageClassificationModel { impl Module for ImageClassificationModel { fn forward(&self, x: &Tensor) -> Result<Tensor> { let all_hidden_states = self.segformer.forward(x)?; - let hidden_states = all_hidden_states.last().unwrap(); + let hidden_states = all_hidden_states.last().context("no last")?; let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?; let mean = hidden_states.mean(1)?; self.classifier.forward(&mean) |