//! Convolution Layers. use crate::BatchNorm; use candle::{Result, Tensor}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Conv1dConfig { pub padding: usize, pub stride: usize, pub dilation: usize, pub groups: usize, } impl Default for Conv1dConfig { fn default() -> Self { Self { padding: 0, stride: 1, dilation: 1, groups: 1, } } } #[derive(Clone, Debug)] pub struct Conv1d { weight: Tensor, bias: Option, config: Conv1dConfig, } impl Conv1d { pub fn new(weight: Tensor, bias: Option, config: Conv1dConfig) -> Self { Self { weight, bias, config, } } pub fn config(&self) -> &Conv1dConfig { &self.config } pub fn weight(&self) -> &Tensor { &self.weight } pub fn bias(&self) -> Option<&Tensor> { self.bias.as_ref() } } impl crate::Module for Conv1d { fn forward(&self, x: &Tensor) -> Result { let x = x.conv1d( &self.weight, self.config.padding, self.config.stride, self.config.dilation, self.config.groups, )?; match &self.bias { None => Ok(x), Some(bias) => { let b = bias.dims1()?; let bias = bias.reshape((1, b, 1))?; Ok(x.broadcast_add(&bias)?) } } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ConvTranspose1dConfig { pub padding: usize, pub output_padding: usize, pub stride: usize, pub dilation: usize, // TODO: support groups. } impl Default for ConvTranspose1dConfig { fn default() -> Self { Self { padding: 0, output_padding: 0, stride: 1, dilation: 1, } } } #[derive(Clone, Debug)] pub struct ConvTranspose1d { weight: Tensor, bias: Option, config: ConvTranspose1dConfig, } impl ConvTranspose1d { pub fn new(weight: Tensor, bias: Option, config: ConvTranspose1dConfig) -> Self { Self { weight, bias, config, } } pub fn config(&self) -> &ConvTranspose1dConfig { &self.config } pub fn weight(&self) -> &Tensor { &self.weight } pub fn bias(&self) -> Option<&Tensor> { self.bias.as_ref() } } impl crate::Module for ConvTranspose1d { fn forward(&self, x: &Tensor) -> Result { let x = x.conv_transpose1d( &self.weight, self.config.padding, self.config.output_padding, self.config.stride, self.config.dilation, )?; match &self.bias { None => Ok(x), Some(bias) => { let b = bias.dims1()?; let bias = bias.reshape((1, b, 1, 1))?; Ok(x.broadcast_add(&bias)?) } } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Conv2dConfig { pub padding: usize, pub stride: usize, pub dilation: usize, pub groups: usize, } impl Default for Conv2dConfig { fn default() -> Self { Self { padding: 0, stride: 1, dilation: 1, groups: 1, } } } #[derive(Clone, Debug)] pub struct Conv2d { weight: Tensor, bias: Option, config: Conv2dConfig, } impl Conv2d { pub fn new(weight: Tensor, bias: Option, config: Conv2dConfig) -> Self { Self { weight, bias, config, } } pub fn config(&self) -> &Conv2dConfig { &self.config } pub fn weight(&self) -> &Tensor { &self.weight } pub fn bias(&self) -> Option<&Tensor> { self.bias.as_ref() } pub fn absorb_bn(&self, bn: &BatchNorm) -> Result { if let Some((w_bn, b_bn)) = bn.weight_and_bias() { let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?; let weight = self .weight() .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?; let bias = match &self.bias { None => b_bn.sub(&(std_.mul(bn.running_mean())?))?, Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?, }; Ok(Self { weight, bias: Some(bias), config: self.config, }) } else { candle::bail!("batch norm does not have weight_and_bias") } } } impl crate::Module for Conv2d { fn forward(&self, x: &Tensor) -> Result { let x = x.conv2d( &self.weight, self.config.padding, self.config.stride, self.config.dilation, self.config.groups, )?; match &self.bias { None => Ok(x), Some(bias) => { let b = bias.dims1()?; let bias = bias.reshape((1, b, 1, 1))?; Ok(x.broadcast_add(&bias)?) } } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ConvTranspose2dConfig { pub padding: usize, pub output_padding: usize, pub stride: usize, pub dilation: usize, // TODO: support groups. } impl Default for ConvTranspose2dConfig { fn default() -> Self { Self { padding: 0, output_padding: 0, stride: 1, dilation: 1, } } } #[derive(Clone, Debug)] pub struct ConvTranspose2d { weight: Tensor, bias: Option, config: ConvTranspose2dConfig, } impl ConvTranspose2d { pub fn new(weight: Tensor, bias: Option, config: ConvTranspose2dConfig) -> Self { Self { weight, bias, config, } } pub fn config(&self) -> &ConvTranspose2dConfig { &self.config } pub fn weight(&self) -> &Tensor { &self.weight } pub fn bias(&self) -> Option<&Tensor> { self.bias.as_ref() } } impl crate::Module for ConvTranspose2d { fn forward(&self, x: &Tensor) -> Result { let x = x.conv_transpose2d( &self.weight, self.config.padding, self.config.output_padding, self.config.stride, self.config.dilation, )?; match &self.bias { None => Ok(x), Some(bias) => { let b = bias.dims1()?; let bias = bias.reshape((1, b, 1, 1))?; Ok(x.broadcast_add(&bias)?) } } } } pub fn conv1d( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: Conv1dConfig, vb: crate::VarBuilder, ) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let ws = vb.get_with_hints( (out_channels, in_channels / cfg.groups, kernel_size), "weight", init_ws, )?; let bound = 1. / (in_channels as f64).sqrt(); let init_bs = crate::Init::Uniform { lo: -bound, up: bound, }; let bs = vb.get_with_hints(out_channels, "bias", init_bs)?; Ok(Conv1d::new(ws, Some(bs), cfg)) } pub fn conv1d_no_bias( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: Conv1dConfig, vb: crate::VarBuilder, ) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let ws = vb.get_with_hints( (out_channels, in_channels / cfg.groups, kernel_size), "weight", init_ws, )?; Ok(Conv1d::new(ws, None, cfg)) } pub fn conv_transpose1d( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: ConvTranspose1dConfig, vb: crate::VarBuilder, ) -> Result { let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt(); let init = crate::Init::Uniform { lo: -bound, up: bound, }; let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; let bs = vb.get_with_hints(out_channels, "bias", init)?; Ok(ConvTranspose1d::new(ws, Some(bs), cfg)) } pub fn conv_transpose1d_no_bias( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: ConvTranspose1dConfig, vb: crate::VarBuilder, ) -> Result { let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt(); let init = crate::Init::Uniform { lo: -bound, up: bound, }; let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; Ok(ConvTranspose1d::new(ws, None, cfg)) } pub fn conv2d( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: Conv2dConfig, vb: crate::VarBuilder, ) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let ws = vb.get_with_hints( ( out_channels, in_channels / cfg.groups, kernel_size, kernel_size, ), "weight", init_ws, )?; let bound = 1. / (in_channels as f64).sqrt(); let init_bs = crate::Init::Uniform { lo: -bound, up: bound, }; let bs = vb.get_with_hints(out_channels, "bias", init_bs)?; Ok(Conv2d::new(ws, Some(bs), cfg)) } pub fn conv2d_no_bias( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: Conv2dConfig, vb: crate::VarBuilder, ) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let ws = vb.get_with_hints( ( out_channels, in_channels / cfg.groups, kernel_size, kernel_size, ), "weight", init_ws, )?; Ok(Conv2d::new(ws, None, cfg)) } pub fn conv_transpose2d( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: ConvTranspose2dConfig, vb: crate::VarBuilder, ) -> Result { let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64; let init = crate::Init::Uniform { lo: -bound, up: bound, }; let ws = vb.get_with_hints( (in_channels, out_channels, kernel_size, kernel_size), "weight", init, )?; let bs = vb.get_with_hints(out_channels, "bias", init)?; Ok(ConvTranspose2d::new(ws, Some(bs), cfg)) } pub fn conv_transpose2d_no_bias( in_channels: usize, out_channels: usize, kernel_size: usize, cfg: ConvTranspose2dConfig, vb: crate::VarBuilder, ) -> Result { let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64; let init = crate::Init::Uniform { lo: -bound, up: bound, }; let ws = vb.get_with_hints( (in_channels, out_channels, kernel_size, kernel_size), "weight", init, )?; Ok(ConvTranspose2d::new(ws, None, cfg)) }