author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-10 10:20:18 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-10 10:20:18 +0100 |
commit | 35f72514f59b3fa4bd321e3e88a75f5b43cf060f (patch) | |
tree | 37dd25098bcf16293744758268a0486337d18431 | /candle-transformers/src/models/segment_anything/mod.rs |
parent | d3f05eae8c4f2df186b46e433be101ac39fceca5 (diff) | |
Move more models to candle-transformers (#796)
* Move dinov2.
* Move efficientnet.
* Move the quantized llama model.
* Move segment-anything.
Diffstat (limited to 'candle-transformers/src/models/segment_anything/mod.rs')
-rw-r--r-- | candle-transformers/src/models/segment_anything/mod.rs | 100 |
1 file changed, 100 insertions, 0 deletions
```diff
diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs
new file mode 100644
index 00000000..c29db70a
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/mod.rs
@@ -0,0 +1,100 @@
+use candle::{Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+pub mod image_encoder;
+pub mod mask_decoder;
+pub mod prompt_encoder;
+pub mod sam;
+pub mod tiny_vit;
+pub mod transformer;
+
+pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+    let inner = if bias {
+        candle_nn::linear(in_dim, out_dim, vb)?
+    } else {
+        candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+    };
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    Ok(Linear { inner, span })
+}
+
+#[derive(Debug)]
+pub struct LayerNorm2d {
+    weight: Tensor,
+    bias: Tensor,
+    num_channels: usize,
+    eps: f64,
+}
+
+impl LayerNorm2d {
+    pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let weight = vb.get(num_channels, "weight")?;
+        let bias = vb.get(num_channels, "bias")?;
+        Ok(Self {
+            weight,
+            bias,
+            num_channels,
+            eps,
+        })
+    }
+}
+
+impl Module for LayerNorm2d {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let u = xs.mean_keepdim(1)?;
+        let xs = xs.broadcast_sub(&u)?;
+        let s = xs.sqr()?.mean_keepdim(1)?;
+        let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
+        xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
+            .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
+    }
+}
+
+#[derive(Debug)]
+pub struct MlpBlock {
+    lin1: Linear,
+    lin2: Linear,
+    activation: candle_nn::Activation,
+    span: tracing::Span,
+}
+
+impl MlpBlock {
+    pub fn new(
+        embedding_dim: usize,
+        mlp_dim: usize,
+        activation: candle_nn::Activation,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+        let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+        let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
+        Ok(Self {
+            lin1,
+            lin2,
+            activation,
+            span,
+        })
+    }
+}
+
+impl Module for MlpBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        xs.apply(&self.lin1)?
+            .apply(&self.activation)?
+            .apply(&self.lin2)
+    }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
```
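The new `mod.rs` only defines shared building blocks (`linear`, `LayerNorm2d`, `MlpBlock`) used by the submodules listed above. As a rough illustration of how they might be exercised, here is a minimal sketch that is not part of the commit: it assumes the module is exposed as `candle_transformers::models::segment_anything`, uses a `VarMap`-backed `VarBuilder` so the layers can be constructed without pretrained weights, and picks arbitrary shapes and a `Gelu` activation purely for demonstration. Within the candle workspace the core crate is imported as `candle`; external users would depend on `candle-core`.

```rust
// Hypothetical usage sketch (not part of the commit): exercises LayerNorm2d and
// MlpBlock from the new module with freshly created, untrained parameters.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder, VarMap};
// Assumed re-export path for the new module inside candle-transformers.
use candle_transformers::models::segment_anything::{LayerNorm2d, MlpBlock};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A VarMap-backed VarBuilder creates variables on demand, so the layers can
    // be instantiated without loading a checkpoint.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // LayerNorm2d normalizes an NCHW tensor over its channel dimension (dim 1).
    let ln = LayerNorm2d::new(8, 1e-6, vb.pp("ln"))?;
    let xs = Tensor::randn(0f32, 1f32, (1, 8, 4, 4), &dev)?;
    println!("layer-norm output shape: {:?}", ln.forward(&xs)?.shape());

    // MlpBlock is a two-layer feed-forward block: lin1 -> activation -> lin2.
    let mlp = MlpBlock::new(8, 32, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
    let xs = Tensor::randn(0f32, 1f32, (1, 16, 8), &dev)?;
    println!("mlp output shape: {:?}", mlp.forward(&xs)?.shape());

    Ok(())
}
```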