summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-28 14:41:28 +0200
committerGitHub <noreply@github.com>2024-09-28 14:41:28 +0200
commit62525e83526465b2d7969c8d3d0213b491b7ccc8 (patch)
tree40cfc61438d654bd0ddc8db322c617f5d9f9b5b2 /candle-examples/examples
parent2c25754281fb6672b9ebf84f6f6f5a5b12efe10d (diff)
downloadcandle-62525e83526465b2d7969c8d3d0213b491b7ccc8.tar.gz
candle-62525e83526465b2d7969c8d3d0213b491b7ccc8.tar.bz2
candle-62525e83526465b2d7969c8d3d0213b491b7ccc8.zip
Remove some extra whitelines. (#2513)
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/mobileclip/main.rs22
1 files changed, 0 insertions, 22 deletions
diff --git a/candle-examples/examples/mobileclip/main.rs b/candle-examples/examples/mobileclip/main.rs
index d505fc7c..d9615c43 100644
--- a/candle-examples/examples/mobileclip/main.rs
+++ b/candle-examples/examples/mobileclip/main.rs
@@ -60,7 +60,6 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
-
for path in paths {
let tensor = candle_examples::imagenet::load_image_with_std_mean(
path,
@@ -70,9 +69,7 @@ fn load_images<T: AsRef<std::path::Path>>(
)?;
images.push(tensor);
}
-
let images = Tensor::stack(&images, 0)?;
-
Ok(images)
}
@@ -80,24 +77,17 @@ pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_name = args.which.model_name();
-
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
-
let model_file = if args.use_pth {
api.get("open_clip_pytorch_model.bin")?
} else {
api.get("open_clip_model.safetensors")?
};
-
let tokenizer = api.get("tokenizer.json")?;
-
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
-
let config = &args.which.config();
-
let device = candle_examples::device(args.cpu)?;
-
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
@@ -105,9 +95,7 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
-
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
-
let vb = if args.use_pth {
VarBuilder::from_pth(&model_file, DType::F32, &device)?
} else {
@@ -115,22 +103,15 @@ pub fn main() -> anyhow::Result<()> {
};
let model = mobileclip::MobileClipModel::new(vb, config)?;
-
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
-
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
-
let softmax_image = softmax(&logits_per_image, 1)?;
-
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
-
println!("softmax_image_vec: {:?}", softmax_image_vec);
-
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
-
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
@@ -171,7 +152,6 @@ pub fn tokenize_sequences(
};
let mut tokens = vec![];
-
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
@@ -185,8 +165,6 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}
-
let input_ids = Tensor::new(tokens, device)?;
-
Ok((input_ids, vec_seq))
}