summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-02-17 18:50:55 +0100
committerGitHub <noreply@github.com>2024-02-17 18:50:55 +0100
commit41416d23762e6cf21aa6e1c3adb6d457f15d7071 (patch)
tree5f0c164b3dcfb70241cb48215c88100d13d5bfff /candle-nn
parent5ebcfeaf0f5af69bb2f74385e8d6b020d4a3b8df (diff)
downloadcandle-41416d23762e6cf21aa6e1c3adb6d457f15d7071.tar.gz
candle-41416d23762e6cf21aa6e1c3adb6d457f15d7071.tar.bz2
candle-41416d23762e6cf21aa6e1c3adb6d457f15d7071.zip
Expose more conv1d functions/structs. (#1726)
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/conv.rs16
-rw-r--r--candle-nn/src/lib.rs5
2 files changed, 19 insertions, 2 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index b1168405..6734ab1f 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -302,6 +302,22 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg))
}
+pub fn conv1d_no_bias(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: Conv1dConfig,
+ vb: crate::VarBuilder,
+) -> Result<Conv1d> {
+ let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+ let ws = vb.get_with_hints(
+ (out_channels, in_channels / cfg.groups, kernel_size),
+ "weight",
+ init_ws,
+ )?;
+ Ok(Conv1d::new(ws, None, cfg))
+}
+
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 6306c55a..3d0e6939 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -19,8 +19,9 @@ pub mod var_map;
pub use activation::{prelu, Activation, PReLU};
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{
- conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
- Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
+ conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
+ conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
+ ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};