summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/musicgen_model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/musicgen/musicgen_model.rs')
-rw-r--r--candle-examples/examples/musicgen/musicgen_model.rs8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 3c5e66f8..212f6818 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
}
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
- let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
+ let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
if seq_len > self.weights.dim(0)? {
self.weights = get_embedding(seq_len, self.embedding_dim)?
}
@@ -170,7 +170,7 @@ impl MusicgenAttention {
kv_states: Option<&Tensor>,
attention_mask: &Tensor,
) -> Result<Tensor> {
- let (b_sz, tgt_len, _) = xs.shape().r3()?;
+ let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
let kv_states = kv_states.unwrap_or(xs);
@@ -308,7 +308,7 @@ impl MusicgenDecoder {
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let dev = input_ids.device();
- let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
+ let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
let b_sz = b_sz_times_codebooks / self.num_codebooks;
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
@@ -352,7 +352,7 @@ impl MusicgenForCausalLM {
}
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
- let (b_sz, seq_len) = input_ids.shape().r2()?;
+ let (b_sz, seq_len) = input_ids.dims2()?;
let hidden_states = self.decoder.forward(input_ids)?;
let lm_logits = self
.lm_heads