summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-nn/src/conv.rs41
-rw-r--r--candle-nn/src/lib.rs5
2 files changed, 45 insertions, 1 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index f985cfd6..fe44c153 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -240,3 +240,44 @@ pub fn conv2d_no_bias(
)?;
Ok(Conv2d::new(ws, None, cfg))
}
+
+pub fn conv_transpose2d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: ConvTranspose2dConfig,
+ vs: crate::VarBuilder,
+) -> Result<ConvTranspose2d> {
+ let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
+ let init = crate::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let ws = vs.get_with_hints(
+ (in_channels, out_channels, kernel_size, kernel_size),
+ "weight",
+ init,
+ )?;
+ let bs = vs.get_with_hints(out_channels, "bias", init)?;
+ Ok(ConvTranspose2d::new(ws, Some(bs), cfg))
+}
+
+pub fn conv_transpose2d_no_bias(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: ConvTranspose2dConfig,
+ vs: crate::VarBuilder,
+) -> Result<ConvTranspose2d> {
+ let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
+ let init = crate::Init::Uniform {
+ lo: -bound,
+ up: bound,
+ };
+ let ws = vs.get_with_hints(
+ (out_channels, in_channels, kernel_size, kernel_size),
+ "weight",
+ init,
+ )?;
+ Ok(ConvTranspose2d::new(ws, None, cfg))
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 6e268f4e..8e5580df 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -16,7 +16,10 @@ pub mod var_map;
pub use activation::Activation;
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
-pub use conv::{conv1d, conv2d, conv2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
+pub use conv::{
+ conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
+ Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
+};
pub use embedding::{embedding, Embedding};
pub use func::{func, Func};
pub use group_norm::{group_norm, GroupNorm};