summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/musicgen/encodec_model.rs75
-rw-r--r--candle-examples/examples/musicgen/t5_model.rs5
2 files changed, 76 insertions, 4 deletions
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index ed8a66b7..f9b883fe 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -1,6 +1,6 @@
use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
use anyhow::Result;
-use candle::Tensor;
+use candle::{DType, IndexOp, Tensor};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -140,6 +140,11 @@ impl EncodecEuclideanCodebook {
embed_avg,
})
}
+
+ fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
+ let quantize = Tensor::embedding(embed_ind, &self.embed)?;
+ Ok(quantize)
+ }
}
#[derive(Debug)]
@@ -152,6 +157,12 @@ impl EncodecVectorQuantization {
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
Ok(Self { codebook })
}
+
+ fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
+ let quantize = self.codebook.decode(embed_ind)?;
+ let quantize = quantize.transpose(1, 2)?;
+ Ok(quantize)
+ }
}
#[derive(Debug)]
@@ -167,6 +178,22 @@ impl EncodecResidualVectorQuantizer {
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
}
+
+ fn decode(&self, codes: &Tensor) -> Result<Tensor> {
+ let mut quantized_out = Tensor::zeros((), DType::F32, &codes.device())?;
+ if codes.dim(0)? != self.layers.len() {
+ anyhow::bail!(
+ "codes shape {:?} does not match the number of quantization layers {}",
+ codes.shape(),
+ self.layers.len()
+ )
+ }
+ for (i, layer) in self.layers.iter().enumerate() {
+ let quantized = layer.decode(&codes.i(i)?)?;
+ quantized_out = quantized.broadcast_add(&quantized_out)?;
+ }
+ Ok(quantized_out)
+ }
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
@@ -188,6 +215,10 @@ impl EncodecLSTM {
}
Ok(Self { layers })
}
+
+ fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+ todo!()
+ }
}
#[derive(Debug)]
@@ -216,10 +247,15 @@ impl EncodecConvTranspose1d {
bias,
})
}
+
+ fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+ todo!()
+ }
}
#[derive(Debug)]
struct EncodecConv1d {
+ causal: bool,
conv: Conv1d,
}
@@ -248,7 +284,17 @@ impl EncodecConv1d {
vb.pp("conv"),
)?,
};
- Ok(Self { conv })
+ Ok(Self {
+ causal: cfg.use_causal_conv,
+ conv,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ // TODO: padding, depending on causal.
+ let xs = self.conv.forward(xs)?;
+ // If we add support for NormType "time_group_norm", we should add some normalization here.
+ Ok(xs)
}
}
@@ -284,6 +330,19 @@ impl EncodecResnetBlock {
shortcut,
})
}
+
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs.clone();
+ let xs = xs.elu(1.)?;
+ let xs = self.block_conv1.forward(&xs)?;
+ let xs = xs.elu(1.)?;
+ let xs = self.block_conv2.forward(&xs)?;
+ let xs = match &self.shortcut {
+ None => (xs + residual)?,
+ Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
+ };
+ Ok(xs)
+ }
}
struct Layer<'a> {
@@ -369,6 +428,10 @@ impl EncodecEncoder {
final_lstm,
})
}
+
+ fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+ todo!()
+ }
}
#[derive(Debug)]
@@ -433,6 +496,10 @@ impl EncodecDecoder {
final_conv,
})
}
+
+ fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+ todo!()
+ }
}
#[derive(Debug)]
@@ -453,4 +520,8 @@ impl EncodecModel {
quantizer,
})
}
+
+ pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+ todo!()
+ }
}
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 23bf7f0d..0444f360 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -206,6 +206,8 @@ impl T5Attention {
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ // TODO: Apply the mask(s)?
+ // TODO: kv caching.
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
let q = self.q.forward(xs)?;
let k = self.k.forward(xs)?;
@@ -220,7 +222,7 @@ impl T5Attention {
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?;
let scores = q.matmul(&k.t()?)?;
- // position_bias_masked
+ // TODO: position_bias_masked
let attn_weights = scores.softmax(D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = self.o.forward(&attn_output)?;
@@ -309,7 +311,6 @@ impl T5Block {
#[derive(Debug)]
struct T5Stack {
- // TODO: Add embed_tokens if needed (shared embedding layer).
block: Vec<T5Block>,
shared: Arc<Embedding>,
final_layer_norm: T5LayerNorm,