author    | Laurent Mazare <laurent.mazare@gmail.com> | 2024-12-22 09:18:13 +0100
committer | GitHub <noreply@github.com>               | 2024-12-22 09:18:13 +0100
commit    | 62ced44ea94da7062430ed6c21ff17b36f41737d (patch)
tree      | ffcb633955da0d743b013266de9b8b45bd59a1f0 /candle-transformers
parent    | 5c2f893e5aa21c9f7c82a00407edb6d76db1d06c (diff)
Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.
* Switch two unwraps to context.
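
For readers browsing this commit, here is a minimal sketch of what a Context trait in the spirit of anyhow::Context looks like and why the unwrap-to-context swaps below matter. The trait this commit actually adds lives in candle-core, outside the candle-transformers diff shown here, so the Error, Result, and trait definitions below are stand-ins mirroring candle's public names, not the committed code:

    use std::fmt::Display;

    // Stand-in error type; candle's real Error has a Msg(String) variant.
    #[derive(Debug)]
    pub enum Error {
        Msg(String),
    }

    pub type Result<T> = std::result::Result<T, Error>;

    // Hypothetical trait definition, modeled on anyhow::Context.
    pub trait Context<T> {
        // Turn a missing value into a descriptive error instead of a panic.
        fn context<C: Display>(self, ctx: C) -> Result<T>;
    }

    impl<T> Context<T> for Option<T> {
        fn context<C: Display>(self, ctx: C) -> Result<T> {
            self.ok_or_else(|| Error::Msg(ctx.to_string()))
        }
    }

    fn main() -> Result<()> {
        // Mirrors the generation/mod.rs hunk: the argmax over logits used to
        // end in .unwrap(), which panics on empty input; .context(..)? returns
        // a recoverable error instead.
        let logits = [0.1f32, 0.9];
        let next_token = logits
            .iter()
            .enumerate()
            .max_by(|(_, u), (_, v)| u.total_cmp(v))
            .map(|(i, _)| i as u32)
            .context("empty logits")?;
        println!("next token: {next_token}");
        Ok(())
    }

The payoff is visible in the first hunk below: an empty logits vector now surfaces as an error the caller can handle, instead of a panic inside the library.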
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/generation/mod.rs                   |  4
-rw-r--r-- | candle-transformers/src/models/chinese_clip/vision_model.rs |  4
-rw-r--r-- | candle-transformers/src/models/clip/vision_model.rs         |  4
-rw-r--r-- | candle-transformers/src/models/efficientnet.rs              |  4
-rw-r--r-- | candle-transformers/src/models/fastvit.rs                   |  4
-rw-r--r-- | candle-transformers/src/models/llava/mod.rs                 | 22
-rw-r--r-- | candle-transformers/src/models/segformer.rs                 |  4
7 files changed, 21 insertions, 25 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index d95a0595..85ffb59c 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -3,7 +3,7 @@
 //! Functionality for modeling sampling strategies and logits processing in text generation
 //! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
 //! and combinations thereof.
-use candle::{DType, Error, Result, Tensor};
+use candle::{Context, DType, Error, Result, Tensor};
 use rand::{distributions::Distribution, SeedableRng};
 
 #[derive(Clone, PartialEq, Debug)]
@@ -45,7 +45,7 @@ impl LogitsProcessor {
             .enumerate()
             .max_by(|(_, u), (_, v)| u.total_cmp(v))
             .map(|(i, _)| i as u32)
-            .unwrap();
+            .context("empty logits")?;
         Ok(next_token)
     }
 
diff --git a/candle-transformers/src/models/chinese_clip/vision_model.rs b/candle-transformers/src/models/chinese_clip/vision_model.rs
index a20535c4..153fe833 100644
--- a/candle-transformers/src/models/chinese_clip/vision_model.rs
+++ b/candle-transformers/src/models/chinese_clip/vision_model.rs
@@ -6,7 +6,7 @@
 //! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
 //! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py)
 
-use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
+use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
 use candle_nn as nn;
 
 use super::{Activation, EncoderConfig};
@@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer {
             .apply(&self.embeddings)?
             .apply(&self.pre_layer_norm)?;
         let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
-        let encoder_outputs = result.last().unwrap();
+        let encoder_outputs = result.last().context("no last")?;
         let pooled_output = encoder_outputs.i((.., 0, ..))?;
         result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
         Ok(result)
diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs
index e64cab16..90314420 100644
--- a/candle-transformers/src/models/clip/vision_model.rs
+++ b/candle-transformers/src/models/clip/vision_model.rs
@@ -6,7 +6,7 @@
 //! https://github.com/openai/CLIP
 //! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
 
-use candle::{IndexOp, Result, Shape, Tensor, D};
+use candle::{Context, IndexOp, Result, Shape, Tensor, D};
 use candle_nn as nn;
 use candle_nn::Module;
 use nn::Conv2dConfig;
@@ -149,7 +149,7 @@ impl ClipVisionTransformer {
             .apply(&self.embeddings)?
             .apply(&self.pre_layer_norm)?;
         let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
-        let encoder_outputs = result.last().unwrap();
+        let encoder_outputs = result.last().context("no last")?;
         let pooled_output = encoder_outputs.i((.., 0, ..))?;
         result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
         Ok(result)
diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs
index 36754f21..be695460 100644
--- a/candle-transformers/src/models/efficientnet.rs
+++ b/candle-transformers/src/models/efficientnet.rs
@@ -3,7 +3,7 @@
 //! See:
 //! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462)
 //!
-use candle::{Result, Tensor, D};
+use candle::{Context, Result, Tensor, D};
 use candle_nn as nn;
 use nn::{Module, VarBuilder};
@@ -289,7 +289,7 @@ impl EfficientNet {
     pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
         let f_p = p.pp("features");
         let first_in_c = configs[0].input_channels;
-        let last_out_c = configs.last().unwrap().out_channels;
+        let last_out_c = configs.last().context("no last")?.out_channels;
         let final_out_c = 4 * last_out_c;
         let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
         let nconfigs = configs.len();
diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs
index 4e296653..3f8664d9 100644
--- a/candle-transformers/src/models/fastvit.rs
+++ b/candle-transformers/src/models/fastvit.rs
@@ -5,7 +5,7 @@
 //!
 //! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py)
 
-use candle::{DType, Result, Tensor, D};
+use candle::{Context, DType, Result, Tensor, D};
 use candle_nn::{
     batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
     BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
@@ -178,7 +178,7 @@ fn squeeze_and_excitation(
 // based on the _fuse_bn_tensor method in timm
 // see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
 fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
-    let (gamma, beta) = bn.weight_and_bias().unwrap();
+    let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
     let mu = bn.running_mean();
     let sigma = (bn.running_var() + bn.eps())?.sqrt();
     let gps = (gamma / sigma)?;
diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs
index c252dbed..bc855538 100644
--- a/candle-transformers/src/models/llava/mod.rs
+++ b/candle-transformers/src/models/llava/mod.rs
@@ -14,7 +14,7 @@
 use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
 use crate::models::llama::{Cache, Llama};
 use crate::models::with_tracing::linear;
-use candle::{bail, Device, IndexOp, Result, Tensor};
+use candle::{bail, Context, Device, IndexOp, Result, Tensor};
 use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
 use fancy_regex::Regex;
 use utils::get_anyres_image_grid_shape;
@@ -145,7 +145,7 @@ impl ClipVisionTower {
         let config = if config.is_none() {
             ClipVisionConfig::clip_vit_large_patch14_336()
         } else {
-            config.clone().unwrap()
+            config.clone().context("no config")?
         };
         let select_layer = match select_layer {
             -1 | -2 => select_layer,
@@ -262,14 +262,14 @@ impl LLaVA {
         let image_features = if mm_patch_merge_type == "flat" {
             image_features
                 .iter()
-                .map(|x| x.flatten(0, 1).unwrap())
-                .collect::<Vec<Tensor>>()
+                .map(|x| x.flatten(0, 1))
+                .collect::<Result<Vec<Tensor>>>()?
         } else if mm_patch_merge_type.starts_with("spatial") {
             let mut new_image_features = Vec::new();
             for (image_idx, image_feature) in image_features.iter().enumerate() {
                 let new_image_feature = if image_feature.dims()[0] > 1 {
-                    let base_image_feature = image_feature.get(0).unwrap();
-                    let patch_image_feature = image_feature.i(1..).unwrap();
+                    let base_image_feature = image_feature.get(0)?;
+                    let patch_image_feature = image_feature.i(1..)?;
                     let height = self.clip_vision_tower.num_patches_per_side();
                     let width = height;
                     assert_eq!(height * width, base_image_feature.dims()[0]);
@@ -313,16 +313,12 @@ impl LLaVA {
                     };
                     Tensor::cat(&[base_image_feature, new_image_feature], 0)?
                 } else {
-                    let new_image_feature = image_feature.get(0).unwrap();
+                    let new_image_feature = image_feature.get(0)?;
                     if mm_patch_merge_type.contains("unpad") {
                         Tensor::cat(
-                            &[
-                                new_image_feature,
-                                self.image_newline.clone().unsqueeze(0).unwrap(),
-                            ],
+                            &[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
                             0,
-                        )
-                        .unwrap()
+                        )?
                     } else {
                         new_image_feature
                     }
diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs
index 9e0461bc..6d750df2 100644
--- a/candle-transformers/src/models/segformer.rs
+++ b/candle-transformers/src/models/segformer.rs
@@ -15,7 +15,7 @@
 //!
 use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
 
-use candle::{Module, ModuleT, Result, Tensor, D};
+use candle::{Context, Module, ModuleT, Result, Tensor, D};
 use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -633,7 +633,7 @@ impl ImageClassificationModel {
 impl Module for ImageClassificationModel {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let all_hidden_states = self.segformer.forward(x)?;
-        let hidden_states = all_hidden_states.last().unwrap();
+        let hidden_states = all_hidden_states.last().context("no last")?;
         let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
         let mean = hidden_states.mean(1)?;
         self.classifier.forward(&mean)
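
One change in this diff goes beyond attaching a message to an unwrap: in llava/mod.rs, the per-element .unwrap() inside map is replaced by collecting into Result<Vec<Tensor>>, so the first failing flatten short-circuits the whole loop and propagates via the ? operator. A standalone sketch of that collect idiom, using std types as stand-ins for candle tensors:

    // Collecting Iterator<Item = Result<T, E>> into Result<Vec<T>, E>:
    // the first Err aborts the collect, replacing per-element unwraps.
    fn parse_all(inputs: &[&str]) -> Result<Vec<i32>, std::num::ParseIntError> {
        inputs
            .iter()
            .map(|s| s.parse::<i32>()) // each item is a Result
            .collect() // Ok(vec) only if every parse succeeded
    }

    fn main() {
        assert_eq!(parse_all(&["1", "2"]), Ok(vec![1, 2]));
        assert!(parse_all(&["1", "x"]).is_err()); // stops at the bad element
    }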