diff options
Diffstat (limited to 'candle-examples/examples/musicgen/nn.rs')
-rw-r--r-- | candle-examples/examples/musicgen/nn.rs | 75 |
1 files changed, 2 insertions, 73 deletions
diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs index 31b1a162..282b3a05 100644 --- a/candle-examples/examples/musicgen/nn.rs +++ b/candle-examples/examples/musicgen/nn.rs @@ -1,62 +1,5 @@ -use anyhow::Result; -use candle::Tensor; - -const MAX_SEQ_LEN: usize = 5000; - -pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>; -pub type Linear = candle_nn::Linear; - -pub fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), "weight")?; - let bias = if bias { - Some(vb.get(size2, "bias")?) - } else { - None - }; - Ok(Linear::new(weight, bias)) -} - -pub type LayerNorm = candle_nn::LayerNorm; - -pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { - let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { - (Ok(weight), Ok(bias)) => (weight, bias), - (Err(err), _) | (_, Err(err)) => { - if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { - (weight, bias) - } else { - return Err(err.into()); - } - } - }; - Ok(LayerNorm::new(weight, bias, eps)) -} - -#[derive(Debug)] -pub struct Dropout { - pr: f64, -} - -impl Dropout { - pub fn new(pr: f64) -> Self { - Self { pr } - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - // TODO - Ok(x.clone()) - } -} - -pub type Embedding = candle_nn::Embedding; - -pub fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - -pub type Conv1d = candle_nn::Conv1d; -pub type Conv1dConfig = candle_nn::Conv1dConfig; +use candle::Result; +use candle_nn::{Conv1d, Conv1dConfig, VarBuilder}; // Applies weight norm for inference by recomputing the weight tensor. This // does not apply to training. @@ -75,17 +18,3 @@ pub fn conv1d_weight_norm( let bias = vb.get(out_c, "bias")?; Ok(Conv1d::new(weight, Some(bias), config)) } - -pub fn conv1d( - in_c: usize, - out_c: usize, - kernel_size: usize, - config: Conv1dConfig, - vb: VarBuilder, -) -> Result<Conv1d> { - let weight = vb.get((out_c, in_c, kernel_size), "weight")?; - let bias = vb.get(out_c, "bias")?; - Ok(Conv1d::new(weight, Some(bias), config)) -} - -pub type HiddenAct = candle_nn::Activation; |