diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 12:26:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 12:26:56 +0100 |
commit | 28c87f6a34e594aca5f558bceebc4c0a9c95911a (patch) | |
tree | 11d702a507de898a7e734aa22349657d04931fb4 /candle-examples/examples/segment-anything/model_transformer.rs | |
parent | c1453f00b11c9dd12c5aa81fb4355ce47d22d477 (diff) | |
download | candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.gz candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.bz2 candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.zip |
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator.
* Add the target point.
* Fix.
* Remove the allow-unused.
* Mask post-processing.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_transformer.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_transformer.rs | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs index 044dce9b..e4de27cb 100644 --- a/candle-examples/examples/segment-anything/model_transformer.rs +++ b/candle-examples/examples/segment-anything/model_transformer.rs @@ -1,4 +1,4 @@ -use candle::{DType, IndexOp, Result, Tensor, D}; +use candle::{Result, Tensor}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; #[derive(Debug)] @@ -7,7 +7,6 @@ struct Attention { k_proj: Linear, v_proj: Linear, out_proj: Linear, - internal_dim: usize, num_heads: usize, } @@ -28,7 +27,6 @@ impl Attention { k_proj, v_proj, out_proj, - internal_dim, num_heads, }) } @@ -85,7 +83,6 @@ impl TwoWayAttentionBlock { skip_first_layer_pe: bool, vb: VarBuilder, ) -> Result<Self> { - let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?; let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?; let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?; let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?; @@ -204,7 +201,6 @@ impl TwoWayTransformer { image_pe: &Tensor, point_embedding: &Tensor, ) -> Result<(Tensor, Tensor)> { - let (bs, c, h, w) = image_embedding.dims4()?; let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?; let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?; |