diff options
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/conv.rs | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index b2483058..cfe86bfa 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -187,10 +187,10 @@ pub fn conv1d( out_channels: usize, kernel_size: usize, cfg: Conv1dConfig, - vs: crate::VarBuilder, + vb: crate::VarBuilder, ) -> Result<Conv1d> { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_with_hints( + let ws = vb.get_with_hints( (out_channels, in_channels / cfg.groups, kernel_size), "weight", init_ws, @@ -200,7 +200,7 @@ pub fn conv1d( lo: -bound, up: bound, }; - let bs = vs.get_with_hints(out_channels, "bias", init_bs)?; + let bs = vb.get_with_hints(out_channels, "bias", init_bs)?; Ok(Conv1d::new(ws, Some(bs), cfg)) } @@ -209,10 +209,10 @@ pub fn conv2d( out_channels: usize, kernel_size: usize, cfg: Conv2dConfig, - vs: crate::VarBuilder, + vb: crate::VarBuilder, ) -> Result<Conv2d> { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_with_hints( + let ws = vb.get_with_hints( ( out_channels, in_channels / cfg.groups, @@ -227,7 +227,7 @@ pub fn conv2d( lo: -bound, up: bound, }; - let bs = vs.get_with_hints(out_channels, "bias", init_bs)?; + let bs = vb.get_with_hints(out_channels, "bias", init_bs)?; Ok(Conv2d::new(ws, Some(bs), cfg)) } @@ -236,10 +236,10 @@ pub fn conv2d_no_bias( out_channels: usize, kernel_size: usize, cfg: Conv2dConfig, - vs: crate::VarBuilder, + vb: crate::VarBuilder, ) -> Result<Conv2d> { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_with_hints( + let ws = vb.get_with_hints( ( out_channels, in_channels / cfg.groups, @@ -257,19 +257,19 @@ pub fn conv_transpose2d( out_channels: usize, kernel_size: usize, cfg: ConvTranspose2dConfig, - vs: crate::VarBuilder, + vb: crate::VarBuilder, ) -> Result<ConvTranspose2d> { let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64; let init = crate::Init::Uniform { lo: -bound, up: bound, }; - let ws = vs.get_with_hints( + let ws = vb.get_with_hints( (in_channels, out_channels, kernel_size, kernel_size), "weight", init, )?; - let bs = vs.get_with_hints(out_channels, "bias", init)?; + let bs = vb.get_with_hints(out_channels, "bias", init)?; Ok(ConvTranspose2d::new(ws, Some(bs), cfg)) } @@ -278,14 +278,14 @@ pub fn conv_transpose2d_no_bias( out_channels: usize, kernel_size: usize, cfg: ConvTranspose2dConfig, - vs: crate::VarBuilder, + vb: crate::VarBuilder, ) -> Result<ConvTranspose2d> { let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64; let init = crate::Init::Uniform { lo: -bound, up: bound, }; - let ws = vs.get_with_hints( + let ws = vb.get_with_hints( (out_channels, in_channels, kernel_size, kernel_size), "weight", init, |