diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-03 15:50:39 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-03 14:50:39 +0100 |
commit | bbec527bb966b5050a9f8a3fe1382ea929e39d41 (patch) | |
tree | 7e8f7a22d71b658fe5647c22f3f2289711457c0d /candle-examples/examples/musicgen/encodec_model.rs | |
parent | f7980e07e073601a95b615ce8cb7008934dcb235 (diff) | |
download | candle-bbec527bb966b5050a9f8a3fe1382ea929e39d41.tar.gz candle-bbec527bb966b5050a9f8a3fe1382ea929e39d41.tar.bz2 candle-bbec527bb966b5050a9f8a3fe1382ea929e39d41.zip |
Fix the musicgen example. (#724)
* Fix the musicgen example.
* Retrieve the weights from the hub.
Diffstat (limited to 'candle-examples/examples/musicgen/encodec_model.rs')
-rw-r--r-- | candle-examples/examples/musicgen/encodec_model.rs | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index 53b252ed..bf33d49d 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -1,7 +1,6 @@ -use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder}; -use anyhow::Result; -use candle::{DType, IndexOp, Tensor}; -use candle_nn::Module; +use crate::nn::conv1d_weight_norm; +use candle::{DType, IndexOp, Result, Tensor}; +use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder}; // Encodec Model // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py @@ -183,7 +182,7 @@ impl EncodecResidualVectorQuantizer { fn decode(&self, codes: &Tensor) -> Result<Tensor> { let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?; if codes.dim(0)? != self.layers.len() { - anyhow::bail!( + candle::bail!( "codes shape {:?} does not match the number of quantization layers {}", codes.shape(), self.layers.len() @@ -321,7 +320,7 @@ impl EncodecResnetBlock { let h = dim / cfg.compress; let mut layer = Layer::new(vb.pp("block")); if dilations.len() != 2 { - anyhow::bail!("expected dilations of size 2") + candle::bail!("expected dilations of size 2") } // TODO: Apply dilations! layer.inc(); |