diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 08:50:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 08:50:27 +0100 |
commit | 989a4807b151f08c651b5027cc1b547a59adf966 (patch) | |
tree | 8bce93e5da15e7961505f579ec6a4a882287283b /candle-examples/examples/segment-anything/model_prompt_encoder.rs | |
parent | 0e250aee4fcff8991c086ba0606a90db92b4e488 (diff) | |
download | candle-989a4807b151f08c651b5027cc1b547a59adf966.tar.gz candle-989a4807b151f08c651b5027cc1b547a59adf966.tar.bz2 candle-989a4807b151f08c651b5027cc1b547a59adf966.zip |
Use shape with holes. (#771)
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_prompt_encoder.rs | 8 |
1 files changed, 4 insertions, 4 deletions
```diff
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index c6ffffd2..aab0c4fd 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -28,10 +28,10 @@ impl PostionEmbeddingRandom {
         let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
         let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
         let x_embed = (x_embed / w as f64)?
-            .reshape((1, w))?
+            .reshape((1, ()))?
             .broadcast_as((h, w))?;
         let y_embed = (y_embed / h as f64)?
-            .reshape((h, 1))?
+            .reshape(((), 1))?
             .broadcast_as((h, w))?;
         let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
         self.pe_encoding(&coords)?.permute((2, 0, 1))
@@ -163,7 +163,7 @@ impl PromptEncoder {
 
     fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
         let boxes = (boxes + 0.5)?;
-        let coords = boxes.reshape((boxes.elem_count() / 4, 2, 2))?;
+        let coords = boxes.reshape(((), 2, 2))?;
         let corner_embedding = self
             .pe_layer
             .forward_with_coords(&coords, self.input_image_size)?;
@@ -200,7 +200,7 @@ impl PromptEncoder {
         let dense_embeddings = match masks {
             None => {
                 let emb = self.no_mask_embed.embeddings();
-                emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
+                emb.reshape((1, (), 1, 1))?.expand((
                     1,
                     emb.elem_count(),
                     self.image_embedding_size.0,
```