diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-12-22 09:18:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-22 09:18:13 +0100 |
commit | 62ced44ea94da7062430ed6c21ff17b36f41737d (patch) | |
tree | ffcb633955da0d743b013266de9b8b45bd59a1f0 /candle-transformers/src/models/segformer.rs | |
parent | 5c2f893e5aa21c9f7c82a00407edb6d76db1d06c (diff) | |
download | candle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.gz candle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.bz2 candle-62ced44ea94da7062430ed6c21ff17b36f41737d.zip |
Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.
* Switch two unwrap to context.
Diffstat (limited to 'candle-transformers/src/models/segformer.rs')
-rw-r--r-- | candle-transformers/src/models/segformer.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs index 9e0461bc..6d750df2 100644 --- a/candle-transformers/src/models/segformer.rs +++ b/candle-transformers/src/models/segformer.rs @@ -15,7 +15,7 @@ //! use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear}; -use candle::{Module, ModuleT, Result, Tensor, D}; +use candle::{Context, Module, ModuleT, Result, Tensor, D}; use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; @@ -633,7 +633,7 @@ impl ImageClassificationModel { impl Module for ImageClassificationModel { fn forward(&self, x: &Tensor) -> Result<Tensor> { let all_hidden_states = self.segformer.forward(x)?; - let hidden_states = all_hidden_states.last().unwrap(); + let hidden_states = all_hidden_states.last().context("no last")?; let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?; let mean = hidden_states.mean(1)?; self.classifier.forward(&mean) |