summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/segformer.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-12-22 09:18:13 +0100
committerGitHub <noreply@github.com>2024-12-22 09:18:13 +0100
commit62ced44ea94da7062430ed6c21ff17b36f41737d (patch)
treeffcb633955da0d743b013266de9b8b45bd59a1f0 /candle-transformers/src/models/segformer.rs
parent5c2f893e5aa21c9f7c82a00407edb6d76db1d06c (diff)
downloadcandle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.gz
candle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.bz2
candle-62ced44ea94da7062430ed6c21ff17b36f41737d.zip
Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context. * Switch two unwrap to context.
Diffstat (limited to 'candle-transformers/src/models/segformer.rs')
-rw-r--r--candle-transformers/src/models/segformer.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs
index 9e0461bc..6d750df2 100644
--- a/candle-transformers/src/models/segformer.rs
+++ b/candle-transformers/src/models/segformer.rs
@@ -15,7 +15,7 @@
//!
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
-use candle::{Module, ModuleT, Result, Tensor, D};
+use candle::{Context, Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
@@ -633,7 +633,7 @@ impl ImageClassificationModel {
impl Module for ImageClassificationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let all_hidden_states = self.segformer.forward(x)?;
- let hidden_states = all_hidden_states.last().unwrap();
+ let hidden_states = all_hidden_states.last().context("no last")?;
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
let mean = hidden_states.mean(1)?;
self.classifier.forward(&mean)