summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/beit/main.rs2
-rw-r--r--candle-examples/examples/blip/main.rs2
-rw-r--r--candle-examples/examples/clip/main.rs2
-rw-r--r--candle-examples/examples/eva2/main.rs2
-rw-r--r--candle-examples/examples/llava/main.rs2
-rw-r--r--candle-examples/examples/moondream/main.rs2
-rw-r--r--candle-examples/examples/segment-anything/main.rs2
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs2
-rw-r--r--candle-examples/examples/trocr/image_processor.rs2
-rw-r--r--candle-examples/examples/yolo-v3/main.rs2
-rw-r--r--candle-examples/examples/yolo-v8/main.rs2
-rw-r--r--candle-examples/src/imagenet.rs2
-rw-r--r--candle-examples/src/lib.rs4
-rw-r--r--candle-onnx/src/eval.rs9
-rw-r--r--candle-onnx/tests/ops.rs72
-rw-r--r--candle-wasm-examples/blip/src/bin/m.rs2
-rw-r--r--candle-wasm-examples/moondream/src/bin/m.rs2
-rw-r--r--candle-wasm-examples/segment-anything/src/bin/m.rs2
-rw-r--r--candle-wasm-examples/yolo/src/worker.rs4
19 files changed, 93 insertions, 26 deletions
diff --git a/candle-examples/examples/beit/main.rs b/candle-examples/examples/beit/main.rs
index a256fd99..47db4c66 100644
--- a/candle-examples/examples/beit/main.rs
+++ b/candle-examples/examples/beit/main.rs
@@ -16,7 +16,7 @@ use candle_transformers::models::beit;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). Beit special normalization is applied.
pub fn load_image384_beit_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
- let img = image::io::Reader::open(p)?
+ let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index 15e36476..d971b49d 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -55,7 +55,7 @@ const SEP_TOKEN_ID: u32 = 102;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
- let img = image::io::Reader::open(p)?
+ let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
diff --git a/candle-examples/examples/clip/main.rs b/candle-examples/examples/clip/main.rs
index f301d211..d057663d 100644
--- a/candle-examples/examples/clip/main.rs
+++ b/candle-examples/examples/clip/main.rs
@@ -33,7 +33,7 @@ struct Args {
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
- let img = image::io::Reader::open(path)?.decode()?;
+ let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
diff --git a/candle-examples/examples/eva2/main.rs b/candle-examples/examples/eva2/main.rs
index 4270075d..1a3a82cc 100644
--- a/candle-examples/examples/eva2/main.rs
+++ b/candle-examples/examples/eva2/main.rs
@@ -16,7 +16,7 @@ use candle_transformers::models::eva2;
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 448, 448). OpenAI normalization is applied.
pub fn load_image448_openai_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
- let img = image::io::Reader::open(p)?
+ let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(448, 448, image::imageops::FilterType::Triangle);
diff --git a/candle-examples/examples/llava/main.rs b/candle-examples/examples/llava/main.rs
index d6c911af..cb809300 100644
--- a/candle-examples/examples/llava/main.rs
+++ b/candle-examples/examples/llava/main.rs
@@ -57,7 +57,7 @@ fn load_image<T: AsRef<std::path::Path>>(
llava_config: &LLaVAConfig,
dtype: DType,
) -> Result<((u32, u32), Tensor)> {
- let img = image::io::Reader::open(path)?.decode()?;
+ let img = image::ImageReader::open(path)?.decode()?;
let img_tensor = process_image(&img, processor, llava_config)?;
Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
}
diff --git a/candle-examples/examples/moondream/main.rs b/candle-examples/examples/moondream/main.rs
index a3dc67ee..6e099888 100644
--- a/candle-examples/examples/moondream/main.rs
+++ b/candle-examples/examples/moondream/main.rs
@@ -208,7 +208,7 @@ struct Args {
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 378, 378).
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
- let img = image::io::Reader::open(p)?
+ let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 10f65c66..204575da 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -139,7 +139,7 @@ pub fn main() -> anyhow::Result<()> {
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
- let mut img = image::io::Reader::open(&args.image)?
+ let mut img = image::ImageReader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index d424444b..b6585afa 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -380,7 +380,7 @@ fn text_embeddings(
}
fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
- let img = image::io::Reader::open(path)?.decode()?;
+ let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (img.height() as usize, img.width() as usize);
let height = height - height % 32;
let width = width - width % 32;
diff --git a/candle-examples/examples/trocr/image_processor.rs b/candle-examples/examples/trocr/image_processor.rs
index 531caa56..3571d6d3 100644
--- a/candle-examples/examples/trocr/image_processor.rs
+++ b/candle-examples/examples/trocr/image_processor.rs
@@ -145,7 +145,7 @@ impl ViTImageProcessor {
pub fn load_images(&self, image_path: Vec<&str>) -> Result<Vec<image::DynamicImage>> {
let mut images: Vec<image::DynamicImage> = Vec::new();
for path in image_path {
- let img = image::io::Reader::open(path)?.decode().unwrap();
+ let img = image::ImageReader::open(path)?.decode().unwrap();
images.push(img);
}
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index a6574697..fb46dac2 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -159,7 +159,7 @@ pub fn main() -> Result<()> {
let net_width = darknet.width()?;
let net_height = darknet.height()?;
- let original_image = image::io::Reader::open(&image_name)?
+ let original_image = image::ImageReader::open(&image_name)?
.decode()
.map_err(candle::Error::wrap)?;
let image = {
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index eb338647..084a42d5 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -390,7 +390,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
for image_name in args.images.iter() {
println!("processing {image_name}");
let mut image_name = std::path::PathBuf::from(image_name);
- let original_image = image::io::Reader::open(&image_name)?
+ let original_image = image::ImageReader::open(&image_name)?
.decode()
.map_err(candle::Error::wrap)?;
let (width, height) = {
diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs
index 6b079870..6fcda424 100644
--- a/candle-examples/src/imagenet.rs
+++ b/candle-examples/src/imagenet.rs
@@ -3,7 +3,7 @@ use candle::{Device, Result, Tensor};
/// Loads an image from disk using the image crate at the requested resolution.
// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
- let img = image::io::Reader::open(p)?
+ let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(res, res, image::imageops::FilterType::Triangle);
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 3308a405..5364bcb2 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -34,7 +34,7 @@ pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize)> {
- let img = image::io::Reader::open(p)?
+ let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?;
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
@@ -65,7 +65,7 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
width: usize,
height: usize,
) -> Result<Tensor> {
- let img = image::io::Reader::open(p)?
+ let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index f0679d5b..d02d7ff0 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -570,6 +570,11 @@ fn simple_eval_(
.map(|&i| {
if i == xs.rank() as i64 {
Ok(xs.rank())
+ } else if i < 0 {
+ // normalize_axis doesn't work correctly here
+ // because we actually want normalized with respect
+ // to the final size, not the current (off by one)
+ Ok(xs.rank() - (-i as usize) + 1)
} else {
xs.normalize_axis(i)
}
@@ -1040,8 +1045,8 @@ fn simple_eval_(
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
}
let idx = if dim > 1 {
- let cycle_len = dim * 2 - 1;
- let skip = (pads_pre[i] as usize) % cycle_len;
+ let cycle_len = dim * 2 - 2;
+ let skip = cycle_len - ((pads_pre[i] as usize) % cycle_len);
let idx = zigzag(0, (dim - 1) as i64)
.skip(skip)
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index bf459d5d..861b3741 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -977,7 +977,59 @@ fn test_constant_of_shape() -> Result<()> {
}
// "Unsqueeze"
-// #[test]
+#[test]
+fn test_unsqueeze() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Unsqueeze".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(
+ vec![
+ 1.0f32, 2.0f32, //
+ 3.0f32, 4.0f32, //
+ ],
+ &[2, 2],
+ &Device::Cpu,
+ )?;
+ let y = Tensor::from_vec(vec![-1i64], &[1], &Device::Cpu)?;
+
+ let inputs = HashMap::from_iter([(INPUT_X.to_string(), x.clone()), (INPUT_Y.to_string(), y)]);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+ assert_eq!(z.dims(), &[2, 2, 1]);
+ assert_eq!(
+ z.flatten_all()?.to_vec1::<f32>()?,
+ x.flatten_all()?.to_vec1::<f32>()?
+ );
+
+ Ok(())
+}
// "Clip"
// #[test]
@@ -3268,13 +3320,23 @@ fn test_if() -> Result<()> {
#[test]
fn test_pad() -> Result<()> {
- let data = Tensor::from_vec(vec![1.0, 1.2, 2.3, 3.4, 4.5, 5.7], (3, 2), &Device::Cpu)?;
- let pads = Tensor::from_vec(vec![0i64, 2, 0, 0], (4,), &Device::Cpu)?;
+ let data = Tensor::from_vec(
+ vec![
+ 1.0, 2.0, 3.0, //
+ 4.0, 5.0, 6.0, //
+ ],
+ (2, 3),
+ &Device::Cpu,
+ )?;
+ let pads = Tensor::from_vec(vec![0i64, 1, 0, 0], (4,), &Device::Cpu)?;
let mode = "reflect";
let expected = Tensor::from_vec(
- vec![1.0, 1.2, 1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7, 4.5, 5.7],
- (3, 4),
+ vec![
+ 2.0, 1.0, 2.0, 3.0, //
+ 5.0, 4.0, 5.0, 6.0, //
+ ],
+ (2, 4),
&Device::Cpu,
)?;
diff --git a/candle-wasm-examples/blip/src/bin/m.rs b/candle-wasm-examples/blip/src/bin/m.rs
index e2ba4fed..61504956 100644
--- a/candle-wasm-examples/blip/src/bin/m.rs
+++ b/candle-wasm-examples/blip/src/bin/m.rs
@@ -124,7 +124,7 @@ impl Model {
impl Model {
fn load_image(&self, image: Vec<u8>) -> Result<Tensor, JsError> {
let device = &Device::Cpu;
- let img = image::io::Reader::new(std::io::Cursor::new(image))
+ let img = image::ImageReader::new(std::io::Cursor::new(image))
.with_guessed_format()?
.decode()
.map_err(|e| JsError::new(&e.to_string()))?
diff --git a/candle-wasm-examples/moondream/src/bin/m.rs b/candle-wasm-examples/moondream/src/bin/m.rs
index 2af6c0d2..27cda1e7 100644
--- a/candle-wasm-examples/moondream/src/bin/m.rs
+++ b/candle-wasm-examples/moondream/src/bin/m.rs
@@ -195,7 +195,7 @@ impl Model {
}
impl Model {
fn load_image(&self, image: Vec<u8>) -> Result<Tensor, JsError> {
- let img = image::io::Reader::new(std::io::Cursor::new(image))
+ let img = image::ImageReader::new(std::io::Cursor::new(image))
.with_guessed_format()?
.decode()
.map_err(|e| JsError::new(&e.to_string()))?
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
index e550a32f..38e9fe3b 100644
--- a/candle-wasm-examples/segment-anything/src/bin/m.rs
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -38,7 +38,7 @@ impl Model {
pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> {
sam::console_log!("image data: {}", image_data.len());
let image_data = std::io::Cursor::new(image_data);
- let image = image::io::Reader::new(image_data)
+ let image = image::ImageReader::new(image_data)
.with_guessed_format()?
.decode()
.map_err(candle::Error::wrap)?;
diff --git a/candle-wasm-examples/yolo/src/worker.rs b/candle-wasm-examples/yolo/src/worker.rs
index 1ecef341..4480d567 100644
--- a/candle-wasm-examples/yolo/src/worker.rs
+++ b/candle-wasm-examples/yolo/src/worker.rs
@@ -48,7 +48,7 @@ impl Model {
) -> Result<Vec<Vec<Bbox>>> {
console_log!("image data: {}", image_data.len());
let image_data = std::io::Cursor::new(image_data);
- let original_image = image::io::Reader::new(image_data)
+ let original_image = image::ImageReader::new(image_data)
.with_guessed_format()?
.decode()
.map_err(candle::Error::wrap)?;
@@ -127,7 +127,7 @@ impl ModelPose {
) -> Result<Vec<Bbox>> {
console_log!("image data: {}", image_data.len());
let image_data = std::io::Cursor::new(image_data);
- let original_image = image::io::Reader::new(image_data)
+ let original_image = image::ImageReader::new(image_data)
.with_guessed_format()?
.decode()
.map_err(candle::Error::wrap)?;