summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_prompt_encoder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index c6ffffd2..aab0c4fd 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -28,10 +28,10 @@ impl PostionEmbeddingRandom {
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let x_embed = (x_embed / w as f64)?
- .reshape((1, w))?
+ .reshape((1, ()))?
.broadcast_as((h, w))?;
let y_embed = (y_embed / h as f64)?
- .reshape((h, 1))?
+ .reshape(((), 1))?
.broadcast_as((h, w))?;
let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
self.pe_encoding(&coords)?.permute((2, 0, 1))
@@ -163,7 +163,7 @@ impl PromptEncoder {
fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
let boxes = (boxes + 0.5)?;
- let coords = boxes.reshape((boxes.elem_count() / 4, 2, 2))?;
+ let coords = boxes.reshape(((), 2, 2))?;
let corner_embedding = self
.pe_layer
.forward_with_coords(&coords, self.input_image_size)?;
@@ -200,7 +200,7 @@ impl PromptEncoder {
let dense_embeddings = match masks {
None => {
let emb = self.no_mask_embed.embeddings();
- emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
+ emb.reshape((1, (), 1, 1))?.expand((
1,
emb.elem_count(),
self.image_embedding_size.0,