summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/encodec_model.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-03 15:50:39 +0200
committerGitHub <noreply@github.com>2023-09-03 14:50:39 +0100
commitbbec527bb966b5050a9f8a3fe1382ea929e39d41 (patch)
tree7e8f7a22d71b658fe5647c22f3f2289711457c0d /candle-examples/examples/musicgen/encodec_model.rs
parentf7980e07e073601a95b615ce8cb7008934dcb235 (diff)
downloadcandle-bbec527bb966b5050a9f8a3fe1382ea929e39d41.tar.gz
candle-bbec527bb966b5050a9f8a3fe1382ea929e39d41.tar.bz2
candle-bbec527bb966b5050a9f8a3fe1382ea929e39d41.zip
Fix the musicgen example. (#724)
* Fix the musicgen example. * Retrieve the weights from the hub.
Diffstat (limited to 'candle-examples/examples/musicgen/encodec_model.rs')
-rw-r--r--candle-examples/examples/musicgen/encodec_model.rs11
1 files changed, 5 insertions, 6 deletions
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs
index 53b252ed..bf33d49d 100644
--- a/candle-examples/examples/musicgen/encodec_model.rs
+++ b/candle-examples/examples/musicgen/encodec_model.rs
@@ -1,7 +1,6 @@
-use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
-use anyhow::Result;
-use candle::{DType, IndexOp, Tensor};
-use candle_nn::Module;
+use crate::nn::conv1d_weight_norm;
+use candle::{DType, IndexOp, Result, Tensor};
+use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@@ -183,7 +182,7 @@ impl EncodecResidualVectorQuantizer {
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
if codes.dim(0)? != self.layers.len() {
- anyhow::bail!(
+ candle::bail!(
"codes shape {:?} does not match the number of quantization layers {}",
codes.shape(),
self.layers.len()
@@ -321,7 +320,7 @@ impl EncodecResnetBlock {
let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block"));
if dilations.len() != 2 {
- anyhow::bail!("expected dilations of size 2")
+ candle::bail!("expected dilations of size 2")
}
// TODO: Apply dilations!
layer.inc();