diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 12:26:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 12:26:56 +0100 |
commit | 28c87f6a34e594aca5f558bceebc4c0a9c95911a (patch) | |
tree | 11d702a507de898a7e734aa22349657d04931fb4 /candle-examples/examples/segment-anything/model_prompt_encoder.rs | |
parent | c1453f00b11c9dd12c5aa81fb4355ce47d22d477 (diff) | |
download | candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.gz candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.bz2 candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.zip |
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator.
* Add the target point.
* Fix.
* Remove the allow-unused.
* Mask post-processing.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_prompt_encoder.rs | 6 |
1 files changed, 3 insertions, 3 deletions
```diff
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index e4291ebb..b401a900 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -1,5 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::VarBuilder;
 
 #[derive(Debug)]
 struct PostionEmbeddingRandom {
@@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
 
     fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
         let device = self.positional_encoding_gaussian_matrix.device();
-        let grid = Tensor::ones((h, w), DType::F32, device)?;
         let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
         let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
         let x_embed = (x_embed / w as f64)?
@@ -157,8 +156,9 @@ impl PromptEncoder {
         let point_embedding = self
             .pe_layer
             .forward_with_coords(&points, self.input_image_size)?;
+        let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
         let zeros = point_embedding.zeros_like()?;
-        let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
+        let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
             &self
                 .not_a_point_embed
                 .embeddings()
```