summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/conv.rs16
-rw-r--r--candle-nn/src/lib.rs5
2 files changed, 19 insertions, 2 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index b1168405..6734ab1f 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -302,6 +302,22 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg))
}
+pub fn conv1d_no_bias(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: Conv1dConfig,
+ vb: crate::VarBuilder,
+) -> Result<Conv1d> {
+ let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+ let ws = vb.get_with_hints(
+ (out_channels, in_channels / cfg.groups, kernel_size),
+ "weight",
+ init_ws,
+ )?;
+ Ok(Conv1d::new(ws, None, cfg))
+}
+
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 6306c55a..3d0e6939 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -19,8 +19,9 @@ pub mod var_map;
pub use activation::{prelu, Activation, PReLU};
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{
- conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
- Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
+ conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
+ conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
+ ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};