Diffstat (limited to 'candle-examples/examples/segment-anything/model_transformer.rs')
-rw-r--r--  candle-examples/examples/segment-anything/model_transformer.rs | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
index 044dce9b..e4de27cb 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

#[derive(Debug)]
@@ -7,7 +7,6 @@ struct Attention {
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
- internal_dim: usize,
num_heads: usize,
}
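
The deleted internal_dim field was dead state: nothing in the attention path reads it, and the value stays recoverable from the stored projections if it is ever needed again. A minimal sketch of that recovery, assuming only that candle_nn::Linear exposes its (out_features, in_features) weight tensor (the helper name is hypothetical, not part of the example):

use candle_nn::Linear;

// Hypothetical helper: derive the internal dimension on demand from the
// q projection's weight, whose first dim is out_features, instead of
// caching it as a struct field.
fn internal_dim(q_proj: &Linear) -> usize {
    q_proj.weight().dims()[0]
}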
@@ -28,7 +27,6 @@ impl Attention {
k_proj,
v_proj,
out_proj,
- internal_dim,
num_heads,
})
}
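
Inside the constructor the value is still computed, since it sizes the projections; only the struct-literal field goes away. A sketch of the surrounding new, reconstructed around the context lines above and relying on the file's own imports from the first hunk (the downsample_rate parameter name and the candle_nn::linear calls follow the SAM reference layout and are assumptions, not verbatim source):

fn new(
    embedding_dim: usize,
    num_heads: usize,
    downsample_rate: usize,
    vb: VarBuilder,
) -> Result<Self> {
    // internal_dim still matters locally: it sizes the q/k/v projections.
    let internal_dim = embedding_dim / downsample_rate;
    let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
    let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
    let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
    let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
    Ok(Self {
        q_proj,
        k_proj,
        v_proj,
        out_proj,
        num_heads,
    })
}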
@@ -85,7 +83,6 @@ impl TwoWayAttentionBlock {
skip_first_layer_pe: bool,
vb: VarBuilder,
) -> Result<Self> {
- let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
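
The diffstat records a single insertion, and it is spent in the import hunk, so this self_attn deletion has no replacement line; for the block to keep building, the same binding must already be created elsewhere in new, which points to a duplicated, shadowed statement being dropped. A standalone illustration of why such a line is pure waste (the hypothetical expensive stands in for Attention::new):

fn expensive() -> i32 {
    // Stand-in for a constructor that allocates real weights.
    42
}

fn main() {
    let x = expensive(); // runs in full, then is shadowed away unused
    let x = expensive(); // the binding that is actually kept
    println!("{x}");
}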
@@ -204,7 +201,6 @@ impl TwoWayTransformer {
image_pe: &Tensor,
point_embedding: &Tensor,
) -> Result<(Tensor, Tensor)> {
- let (bs, c, h, w) = image_embedding.dims4()?;
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
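
The dropped destructuring bound (bs, c, h, w) without ever using them; the two lines that follow do the real work of turning the CNN feature map into a token sequence. A self-contained sketch of that reshaping with illustrative sizes (batch 1, 256 channels, a 64x64 grid; none of these come from the model config):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let image_embedding = Tensor::zeros((1, 256, 64, 64), DType::F32, &Device::Cpu)?;
    // flatten_from(2) merges the spatial dims: (1, 256, 64, 64) -> (1, 256, 4096).
    // permute((0, 2, 1)) then swaps channels and positions: -> (1, 4096, 256),
    // i.e. one 256-dim token per spatial location, as the transformer expects.
    let tokens = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
    assert_eq!(tokens.dims(), &[1, 4096, 256]);
    Ok(())
}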