author | shua <gpg@isthisa.email> | 2024-07-23 23:10:57 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-23 23:10:57 +0200 |
commit | 6056fd5c90801971733bf8126a02b31fcb76980a (patch) | |
tree | 89121977e53deeae9c967c61cfd9cd8901c5ad6c /candle-wasm-examples/moondream | |
parent | ebc9aa60bc121d4d7245385df6180c44268f0bdc (diff) | |
onnx: fix pad, unsqueeze (#2317)
* onnx: fix pad, unsqueeze
Both implementations have off-by-one errors:
- Pad: the 'reflect' index cycle for e.g. `dim==3` is `[0,1,2,1]`, which has
length 4 (`dim*2 - 2`), not 5 (`dim*2 - 1` in the current code).
- Unsqueeze: axis -1 on a tensor with `dim==3` should resolve to 3
(`dim+index+1`), not 2 (`dim+index` in the current code).
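The two off-by-one fixes can be sketched in isolation. This is a minimal illustration under stated assumptions, not the code from the candle ONNX evaluator: the function names `reflect_cycle_len` and `normalize_unsqueeze_axis` are hypothetical, and `rank` stands for what the message calls `dim` in the Unsqueeze case (the number of dimensions of the input tensor).

```rust
/// Length of one period of the 'reflect' index cycle for an axis of size `dim`.
/// For `dim == 3` the period is [0, 1, 2, 1], i.e. `2 * dim - 2` entries.
/// (Hypothetical helper, named for illustration only.)
fn reflect_cycle_len(dim: usize) -> usize {
    assert!(dim >= 2, "reflect padding needs at least two elements");
    2 * dim - 2
}

/// Resolve a possibly negative Unsqueeze axis for an input of rank `rank`.
/// The output has `rank + 1` dimensions, so a negative axis counts back
/// from `rank + 1`, which gives `rank + axis + 1`.
/// (Hypothetical helper, named for illustration only.)
fn normalize_unsqueeze_axis(rank: usize, axis: i64) -> usize {
    let out_rank = rank as i64 + 1;
    let resolved = if axis < 0 { out_rank + axis } else { axis };
    assert!((0..out_rank).contains(&resolved), "axis out of range");
    resolved as usize
}
```

With these definitions, `reflect_cycle_len(3)` is 4 and `normalize_unsqueeze_axis(3, -1)` is 3, matching the two bullet points above.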
In addition, Pad calculates the starting offset into the index cycle incorrectly.
To pad 2 elements at the start with an index cycle of length 6, we should
skip 4 elements (so that index 0 lands right after the padding), but the
current code skips 2. A more visual representation is below:
```
pad_start: 2
data: [a,b,c,d]
indices: [0, 1, 2, 3, 2, 1, 0, 1, 2, 3, 2, 1, 0, ..] // zigzag between 0..4
actual: skip [ c d| c b a b]
expected: ~ skip ~ [ c b| a b c d]
```
The values between `[` and `|` are padding and the values between
`|` and `]` in the example should match the original data being padded.
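The corrected start offset can be reproduced with a small standalone function. This is only a sketch of the indexing described above, assuming a single axis and a hypothetical helper name `reflect_pad_indices`; it does not use candle's tensor API and is not the patched implementation itself.

```rust
/// Indices into an axis of length `dim` that yield 'reflect' padding with
/// `pad_start` elements before the data and `pad_end` elements after it.
/// (Hypothetical helper, named for illustration only.)
fn reflect_pad_indices(dim: usize, pad_start: usize, pad_end: usize) -> Vec<usize> {
    assert!(dim >= 2, "reflect padding needs at least two elements");
    // One period of the zigzag: 0, 1, .., dim-1, dim-2, .., 1
    let cycle: Vec<usize> = (0..dim).chain((1..dim - 1).rev()).collect();
    let cycle_len = cycle.len(); // equals 2 * dim - 2
    // Skip far enough into the cycle that index 0 lands right after the
    // `pad_start` padding elements (the buggy version skipped `pad_start`).
    let skip = (cycle_len - pad_start % cycle_len) % cycle_len;
    cycle
        .iter()
        .copied()
        .cycle()
        .skip(skip)
        .take(pad_start + dim + pad_end)
        .collect()
}
```

For the example in the diagram, `reflect_pad_indices(4, 2, 0)` returns `[2, 1, 0, 1, 2, 3]`, which maps `[a,b,c,d]` to `[c b | a b c d]`. Skipping only `pad_start` elements instead would start the sequence at `[2, 3, 2, 1, 0, 1]`, the "actual" row.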
* Fix clippy lints.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-wasm-examples/moondream')
-rw-r--r-- | candle-wasm-examples/moondream/src/bin/m.rs | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-wasm-examples/moondream/src/bin/m.rs b/candle-wasm-examples/moondream/src/bin/m.rs
index 2af6c0d2..27cda1e7 100644
--- a/candle-wasm-examples/moondream/src/bin/m.rs
+++ b/candle-wasm-examples/moondream/src/bin/m.rs
@@ -195,7 +195,7 @@ impl Model {
 }
 impl Model {
     fn load_image(&self, image: Vec<u8>) -> Result<Tensor, JsError> {
-        let img = image::io::Reader::new(std::io::Cursor::new(image))
+        let img = image::ImageReader::new(std::io::Cursor::new(image))
             .with_guessed_format()?
             .decode()
             .map_err(|e| JsError::new(&e.to_string()))?