summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_transformer.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-08 12:26:56 +0100
committer: GitHub <noreply@github.com> 2023-09-08 12:26:56 +0100
commit 28c87f6a34e594aca5f558bceebc4c0a9c95911a (patch)
tree 11d702a507de898a7e734aa22349657d04931fb4 /candle-examples/examples/segment-anything/model_transformer.rs
parent c1453f00b11c9dd12c5aa81fb4355ce47d22d477 (diff)
downloadcandle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.gz
candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.bz2
candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.zip
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator. * Add the target point. * Fix. * Remove the allow-unused. * Mask post-processing.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_transformer.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_transformer.rs6
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
index 044dce9b..e4de27cb 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
@@ -7,7 +7,6 @@ struct Attention {
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
- internal_dim: usize,
num_heads: usize,
}
@@ -28,7 +27,6 @@ impl Attention {
k_proj,
v_proj,
out_proj,
- internal_dim,
num_heads,
})
}
@@ -85,7 +83,6 @@ impl TwoWayAttentionBlock {
skip_first_layer_pe: bool,
vb: VarBuilder,
) -> Result<Self> {
- let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
@@ -204,7 +201,6 @@ impl TwoWayTransformer {
image_pe: &Tensor,
point_embedding: &Tensor,
) -> Result<(Tensor, Tensor)> {
- let (bs, c, h, w) = image_embedding.dims4()?;
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;