author    Laurent Mazare <laurent.mazare@gmail.com>    2023-09-10 10:20:18 +0100
committer GitHub <noreply@github.com>    2023-09-10 10:20:18 +0100
commit    35f72514f59b3fa4bd321e3e88a75f5b43cf060f (patch)
tree      37dd25098bcf16293744758268a0486337d18431 /candle-transformers/src/models/segment_anything/mod.rs
parent    d3f05eae8c4f2df186b46e433be101ac39fceca5 (diff)
Move more models to candle-transformers (#796)
* Move dinov2.
* Move efficientnet.
* Move the quantized llama model.
* Move segment-anything.
Diffstat (limited to 'candle-transformers/src/models/segment_anything/mod.rs')
-rw-r--r--  candle-transformers/src/models/segment_anything/mod.rs  100
1 file changed, 100 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs
new file mode 100644
index 00000000..c29db70a
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/mod.rs
@@ -0,0 +1,100 @@
+use candle::{Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+pub mod image_encoder;
+pub mod mask_decoder;
+pub mod prompt_encoder;
+pub mod sam;
+pub mod tiny_vit;
+pub mod transformer;
+
+pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+ let inner = if bias {
+ candle_nn::linear(in_dim, out_dim, vb)?
+ } else {
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
+}
+
+#[derive(Debug)]
+pub struct LayerNorm2d {
+ weight: Tensor,
+ bias: Tensor,
+ num_channels: usize,
+ eps: f64,
+}
+
+impl LayerNorm2d {
+ pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(num_channels, "weight")?;
+ let bias = vb.get(num_channels, "bias")?;
+ Ok(Self {
+ weight,
+ bias,
+ num_channels,
+ eps,
+ })
+ }
+}
+
+impl Module for LayerNorm2d {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let u = xs.mean_keepdim(1)?;
+ let xs = xs.broadcast_sub(&u)?;
+ let s = xs.sqr()?.mean_keepdim(1)?;
+ let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
+ xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
+ .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
+ }
+}
+
+#[derive(Debug)]
+pub struct MlpBlock {
+ lin1: Linear,
+ lin2: Linear,
+ activation: candle_nn::Activation,
+ span: tracing::Span,
+}
+
+impl MlpBlock {
+ pub fn new(
+ embedding_dim: usize,
+ mlp_dim: usize,
+ activation: candle_nn::Activation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+ let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
+ Ok(Self {
+ lin1,
+ lin2,
+ activation,
+ span,
+ })
+ }
+}
+
+impl Module for MlpBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.lin1)?
+ .apply(&self.activation)?
+ .apply(&self.lin2)
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
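
A minimal usage sketch for the helpers added in this file (not part of the commit itself). It assumes the candle-core crate is imported as `candle`, as in the file above, that the module is reachable as candle_transformers::models::segment_anything, and uses candle_nn::VarBuilder::zeros to get zero-initialized parameters just to exercise the shapes.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Activation, Module, VarBuilder};
use candle_transformers::models::segment_anything::MlpBlock;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Zero-initialized parameters are enough to check shapes end to end.
    let vb = VarBuilder::zeros(DType::F32, &dev);
    // MlpBlock expands from embedding_dim (8) to mlp_dim (32) and back,
    // so the output has the same trailing dimension as the input.
    let mlp = MlpBlock::new(8, 32, Activation::Gelu, vb)?;
    let xs = Tensor::randn(0f32, 1f32, (2, 8), &dev)?;
    let ys = mlp.forward(&xs)?;
    assert_eq!(ys.dims(), &[2, 8]);
    Ok(())
}

LayerNorm2d, by contrast, expects (batch, channels, h, w) inputs: its forward pass takes the mean and variance over dimension 1 and reshapes the weight and bias to (1, num_channels, 1, 1), i.e. it normalizes per spatial position across channels.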