diff options
Diffstat (limited to 'candle-examples/examples/musicgen/nn.rs')
-rw-r--r-- | candle-examples/examples/musicgen/nn.rs | 20 |
1 files changed, 0 insertions, 20 deletions
diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs deleted file mode 100644 index 282b3a05..00000000 --- a/candle-examples/examples/musicgen/nn.rs +++ /dev/null @@ -1,20 +0,0 @@ -use candle::Result; -use candle_nn::{Conv1d, Conv1dConfig, VarBuilder}; - -// Applies weight norm for inference by recomputing the weight tensor. This -// does not apply to training. -// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html -pub fn conv1d_weight_norm( - in_c: usize, - out_c: usize, - kernel_size: usize, - config: Conv1dConfig, - vb: VarBuilder, -) -> Result<Conv1d> { - let weight_g = vb.get((out_c, 1, 1), "weight_g")?; - let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?; - let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; - let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?; - let bias = vb.get(out_c, "bias")?; - Ok(Conv1d::new(weight, Some(bias), config)) -} |