summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/nn.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/musicgen/nn.rs')
-rw-r--r--candle-examples/examples/musicgen/nn.rs75
1 files changed, 2 insertions, 73 deletions
diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs
index 31b1a162..282b3a05 100644
--- a/candle-examples/examples/musicgen/nn.rs
+++ b/candle-examples/examples/musicgen/nn.rs
@@ -1,62 +1,5 @@
-use anyhow::Result;
-use candle::Tensor;
-
-const MAX_SEQ_LEN: usize = 5000;
-
-pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>;
-pub type Linear = candle_nn::Linear;
-
-pub fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- let bias = if bias {
- Some(vb.get(size2, "bias")?)
- } else {
- None
- };
- Ok(Linear::new(weight, bias))
-}
-
-pub type LayerNorm = candle_nn::LayerNorm;
-
-pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
- let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
- (Ok(weight), Ok(bias)) => (weight, bias),
- (Err(err), _) | (_, Err(err)) => {
- if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
- (weight, bias)
- } else {
- return Err(err.into());
- }
- }
- };
- Ok(LayerNorm::new(weight, bias, eps))
-}
-
-#[derive(Debug)]
-pub struct Dropout {
- pr: f64,
-}
-
-impl Dropout {
- pub fn new(pr: f64) -> Self {
- Self { pr }
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- // TODO
- Ok(x.clone())
- }
-}
-
-pub type Embedding = candle_nn::Embedding;
-
-pub fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
-pub type Conv1d = candle_nn::Conv1d;
-pub type Conv1dConfig = candle_nn::Conv1dConfig;
+use candle::Result;
+use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
@@ -75,17 +18,3 @@ pub fn conv1d_weight_norm(
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
-
-pub fn conv1d(
- in_c: usize,
- out_c: usize,
- kernel_size: usize,
- config: Conv1dConfig,
- vb: VarBuilder,
-) -> Result<Conv1d> {
- let weight = vb.get((out_c, in_c, kernel_size), "weight")?;
- let bias = vb.get(out_c, "bias")?;
- Ok(Conv1d::new(weight, Some(bias), config))
-}
-
-pub type HiddenAct = candle_nn::Activation;