summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/chinese_clip/vision_model.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-12-22 09:18:13 +0100
committerGitHub <noreply@github.com>2024-12-22 09:18:13 +0100
commit62ced44ea94da7062430ed6c21ff17b36f41737d (patch)
treeffcb633955da0d743b013266de9b8b45bd59a1f0 /candle-transformers/src/models/chinese_clip/vision_model.rs
parent5c2f893e5aa21c9f7c82a00407edb6d76db1d06c (diff)
downloadcandle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.gz
candle-62ced44ea94da7062430ed6c21ff17b36f41737d.tar.bz2
candle-62ced44ea94da7062430ed6c21ff17b36f41737d.zip
Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context. * Switch two unwrap to context.
Diffstat (limited to 'candle-transformers/src/models/chinese_clip/vision_model.rs')
-rw-r--r-- candle-transformers/src/models/chinese_clip/vision_model.rs | 4
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/candle-transformers/src/models/chinese_clip/vision_model.rs b/candle-transformers/src/models/chinese_clip/vision_model.rs
index a20535c4..153fe833 100644
--- a/candle-transformers/src/models/chinese_clip/vision_model.rs
+++ b/candle-transformers/src/models/chinese_clip/vision_model.rs
@@ -6,7 +6,7 @@
//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
//! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py)
-use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
+use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle_nn as nn;
use super::{Activation, EncoderConfig};
@@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer {
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
- let encoder_outputs = result.last().unwrap();
+ let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)