summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-12-22 09:18:13 +0100
committerGitHub <noreply@github.com>2024-12-22 09:18:13 +0100
commit62ced44ea94da7062430ed6c21ff17b36f41737d (patch)
treeffcb633955da0d743b013266de9b8b45bd59a1f0 /candle-transformers
parent5c2f893e5aa21c9f7c82a00407edb6d76db1d06c (diff)
downloadcandle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.gz
candle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.bz2
candle-62ced44ea94da7062430ed6c21ff17b36f41737d.zip
Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.
* Switch two unwraps to context.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/generation/mod.rs4
-rw-r--r--candle-transformers/src/models/chinese_clip/vision_model.rs4
-rw-r--r--candle-transformers/src/models/clip/vision_model.rs4
-rw-r--r--candle-transformers/src/models/efficientnet.rs4
-rw-r--r--candle-transformers/src/models/fastvit.rs4
-rw-r--r--candle-transformers/src/models/llava/mod.rs22
-rw-r--r--candle-transformers/src/models/segformer.rs4
7 files changed, 21 insertions, 25 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index d95a0595..85ffb59c 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -3,7 +3,7 @@
//! Functionality for modeling sampling strategies and logits processing in text generation
//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
//! and combinations thereof.
-use candle::{DType, Error, Result, Tensor};
+use candle::{Context, DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
#[derive(Clone, PartialEq, Debug)]
@@ -45,7 +45,7 @@ impl LogitsProcessor {
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
- .unwrap();
+ .context("empty logits")?;
Ok(next_token)
}
diff --git a/candle-transformers/src/models/chinese_clip/vision_model.rs b/candle-transformers/src/models/chinese_clip/vision_model.rs
index a20535c4..153fe833 100644
--- a/candle-transformers/src/models/chinese_clip/vision_model.rs
+++ b/candle-transformers/src/models/chinese_clip/vision_model.rs
@@ -6,7 +6,7 @@
//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
//! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py_
-use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
+use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle_nn as nn;
use super::{Activation, EncoderConfig};
@@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer {
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
- let encoder_outputs = result.last().unwrap();
+ let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
diff --git a/candle-transformers/src/models/clip/vision_model.rs b/candle-transformers/src/models/clip/vision_model.rs
index e64cab16..90314420 100644
--- a/candle-transformers/src/models/clip/vision_model.rs
+++ b/candle-transformers/src/models/clip/vision_model.rs
@@ -6,7 +6,7 @@
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
-use candle::{IndexOp, Result, Shape, Tensor, D};
+use candle::{Context, IndexOp, Result, Shape, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use nn::Conv2dConfig;
@@ -149,7 +149,7 @@ impl ClipVisionTransformer {
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
- let encoder_outputs = result.last().unwrap();
+ let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs
index 36754f21..be695460 100644
--- a/candle-transformers/src/models/efficientnet.rs
+++ b/candle-transformers/src/models/efficientnet.rs
@@ -3,7 +3,7 @@
//! See:
//! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462)
//!
-use candle::{Result, Tensor, D};
+use candle::{Context, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
@@ -289,7 +289,7 @@ impl EfficientNet {
pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
- let last_out_c = configs.last().unwrap().out_channels;
+ let last_out_c = configs.last().context("no last")?.out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs
index 4e296653..3f8664d9 100644
--- a/candle-transformers/src/models/fastvit.rs
+++ b/candle-transformers/src/models/fastvit.rs
@@ -5,7 +5,7 @@
//!
//! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py)
-use candle::{DType, Result, Tensor, D};
+use candle::{Context, DType, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
@@ -178,7 +178,7 @@ fn squeeze_and_excitation(
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
- let (gamma, beta) = bn.weight_and_bias().unwrap();
+ let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs
index c252dbed..bc855538 100644
--- a/candle-transformers/src/models/llava/mod.rs
+++ b/candle-transformers/src/models/llava/mod.rs
@@ -14,7 +14,7 @@ use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer}
use crate::models::llama::{Cache, Llama};
use crate::models::with_tracing::linear;
-use candle::{bail, Device, IndexOp, Result, Tensor};
+use candle::{bail, Context, Device, IndexOp, Result, Tensor};
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
use fancy_regex::Regex;
use utils::get_anyres_image_grid_shape;
@@ -145,7 +145,7 @@ impl ClipVisionTower {
let config = if config.is_none() {
ClipVisionConfig::clip_vit_large_patch14_336()
} else {
- config.clone().unwrap()
+ config.clone().context("no config")?
};
let select_layer = match select_layer {
-1 | -2 => select_layer,
@@ -262,14 +262,14 @@ impl LLaVA {
let image_features = if mm_patch_merge_type == "flat" {
image_features
.iter()
- .map(|x| x.flatten(0, 1).unwrap())
- .collect::<Vec<Tensor>>()
+ .map(|x| x.flatten(0, 1))
+ .collect::<Result<Vec<Tensor>>>()?
} else if mm_patch_merge_type.starts_with("spatial") {
let mut new_image_features = Vec::new();
for (image_idx, image_feature) in image_features.iter().enumerate() {
let new_image_feature = if image_feature.dims()[0] > 1 {
- let base_image_feature = image_feature.get(0).unwrap();
- let patch_image_feature = image_feature.i(1..).unwrap();
+ let base_image_feature = image_feature.get(0)?;
+ let patch_image_feature = image_feature.i(1..)?;
let height = self.clip_vision_tower.num_patches_per_side();
let width = height;
assert_eq!(height * width, base_image_feature.dims()[0]);
@@ -313,16 +313,12 @@ impl LLaVA {
};
Tensor::cat(&[base_image_feature, new_image_feature], 0)?
} else {
- let new_image_feature = image_feature.get(0).unwrap();
+ let new_image_feature = image_feature.get(0)?;
if mm_patch_merge_type.contains("unpad") {
Tensor::cat(
- &[
- new_image_feature,
- self.image_newline.clone().unsqueeze(0).unwrap(),
- ],
+ &[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
0,
- )
- .unwrap()
+ )?
} else {
new_image_feature
}
diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs
index 9e0461bc..6d750df2 100644
--- a/candle-transformers/src/models/segformer.rs
+++ b/candle-transformers/src/models/segformer.rs
@@ -15,7 +15,7 @@
//!
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
-use candle::{Module, ModuleT, Result, Tensor, D};
+use candle::{Context, Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
@@ -633,7 +633,7 @@ impl ImageClassificationModel {
impl Module for ImageClassificationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let all_hidden_states = self.segformer.forward(x)?;
- let hidden_states = all_hidden_states.last().unwrap();
+ let hidden_states = all_hidden_states.last().context("no last")?;
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
let mean = hidden_states.mean(1)?;
self.classifier.forward(&mean)