summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-transformers/src/models/resnet.rs118
1 files changed, 118 insertions, 0 deletions
diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs
index dd6eee58..f2588e01 100644
--- a/candle-transformers/src/models/resnet.rs
+++ b/candle-transformers/src/models/resnet.rs
@@ -129,3 +129,121 @@ pub fn resnet34(num_classes: usize, vb: VarBuilder) -> Result<Func> {
pub fn resnet34_no_final_layer(vb: VarBuilder) -> Result<Func> {
resnet(None, 3, 4, 6, 3, vb)
}
+
+// Bottleneck versions for ResNet 50, 101, and 152.
+fn bottleneck_block(
+ c_in: usize,
+ c_out: usize,
+ stride: usize,
+ e: usize,
+ vb: VarBuilder,
+) -> Result<Func> {
+ let e_dim = e * c_out;
+ let conv1 = conv2d(c_in, c_out, 1, 0, 1, vb.pp("conv1"))?;
+ let bn1 = batch_norm(c_out, 1e-5, vb.pp("bn1"))?;
+ let conv2 = conv2d(c_out, c_out, 3, 1, stride, vb.pp("conv2"))?;
+ let bn2 = batch_norm(c_out, 1e-5, vb.pp("bn2"))?;
+ let conv3 = conv2d(c_out, e_dim, 1, 0, 1, vb.pp("conv3"))?;
+ let bn3 = batch_norm(e_dim, 1e-5, vb.pp("bn3"))?;
+ let downsample = downsample(c_in, e_dim, stride, vb.pp("downsample"))?;
+ Ok(Func::new(move |xs| {
+ let ys = xs
+ .apply(&conv1)?
+ .apply(&bn1)?
+ .relu()?
+ .apply(&conv2)?
+ .apply(&bn2)?
+ .relu()?
+ .apply(&conv3)?
+ .apply(&bn3)?;
+ (xs.apply(&downsample)? + ys)?.relu()
+ }))
+}
+
+fn bottleneck_layer(
+ c_in: usize,
+ c_out: usize,
+ stride: usize,
+ cnt: usize,
+ vb: VarBuilder,
+) -> Result<Func> {
+ let mut layers = Vec::with_capacity(cnt);
+ for index in 0..cnt {
+ let l_in = if index == 0 { c_in } else { 4 * c_out };
+ let stride = if index == 0 { stride } else { 1 };
+ layers.push(bottleneck_block(l_in, c_out, stride, 4, vb.pp(index))?)
+ }
+ Ok(Func::new(move |xs| {
+ let mut xs = xs.clone();
+ for layer in layers.iter() {
+ xs = xs.apply(layer)?
+ }
+ Ok(xs)
+ }))
+}
+
+fn bottleneck_resnet(
+ nclasses: Option<usize>,
+ c1: usize,
+ c2: usize,
+ c3: usize,
+ c4: usize,
+ vb: VarBuilder,
+) -> Result<Func> {
+ let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
+ let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
+ let layer1 = bottleneck_layer(64, 64, 1, c1, vb.pp("layer1"))?;
+ let layer2 = bottleneck_layer(4 * 64, 128, 2, c2, vb.pp("layer2"))?;
+ let layer3 = bottleneck_layer(4 * 128, 256, 2, c3, vb.pp("layer3"))?;
+ let layer4 = bottleneck_layer(4 * 256, 512, 2, c4, vb.pp("layer4"))?;
+ let fc = match nclasses {
+ None => None,
+ Some(nclasses) => {
+ let linear = candle_nn::linear(4 * 512, nclasses, vb.pp("fc"))?;
+ Some(linear)
+ }
+ };
+ Ok(Func::new(move |xs| {
+ let xs = xs
+ .apply(&conv1)?
+ .apply(&bn1)?
+ .relu()?
+ .pad_with_same(D::Minus1, 1, 1)?
+ .pad_with_same(D::Minus2, 1, 1)?
+ .max_pool2d_with_stride(3, 2)?
+ .apply(&layer1)?
+ .apply(&layer2)?
+ .apply(&layer3)?
+ .apply(&layer4)?
+ .mean(D::Minus1)?
+ .mean(D::Minus1)?;
+ match &fc {
+ None => Ok(xs),
+ Some(fc) => xs.apply(fc),
+ }
+ }))
+}
+
+pub fn resnet50(num_classes: usize, vb: VarBuilder) -> Result<Func> {
+ bottleneck_resnet(Some(num_classes), 3, 4, 6, 3, vb)
+}
+
+pub fn resnet50_no_final_layer(vb: VarBuilder) -> Result<Func> {
+ bottleneck_resnet(None, 3, 4, 6, 3, vb)
+}
+
+pub fn resnet101(num_classes: usize, vb: VarBuilder) -> Result<Func> {
+ bottleneck_resnet(Some(num_classes), 3, 4, 23, 3, vb)
+}
+
+pub fn resnet101_no_final_layer(vb: VarBuilder) -> Result<Func> {
+ bottleneck_resnet(None, 3, 4, 23, 3, vb)
+}
+
+pub fn resnet152(num_classes: usize, vb: VarBuilder) -> Result<Func> {
+ bottleneck_resnet(Some(num_classes), 3, 8, 36, 3, vb)
+}
+
+pub fn resnet152_no_final_layer(vb: VarBuilder) -> Result<Func> {
+ bottleneck_resnet(None, 3, 8, 36, 3, vb)
+}