diff options
Diffstat (limited to 'candle-examples/examples/musicgen/musicgen_model.rs')
-rw-r--r-- | candle-examples/examples/musicgen/musicgen_model.rs | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs index 3c5e66f8..212f6818 100644 --- a/candle-examples/examples/musicgen/musicgen_model.rs +++ b/candle-examples/examples/musicgen/musicgen_model.rs @@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding { } fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { - let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?; + let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?; if seq_len > self.weights.dim(0)? { self.weights = get_embedding(seq_len, self.embedding_dim)? } @@ -170,7 +170,7 @@ impl MusicgenAttention { kv_states: Option<&Tensor>, attention_mask: &Tensor, ) -> Result<Tensor> { - let (b_sz, tgt_len, _) = xs.shape().r3()?; + let (b_sz, tgt_len, _) = xs.dims3()?; let query_states = (self.q_proj.forward(xs)? * self.scaling)?; let kv_states = kv_states.unwrap_or(xs); @@ -308,7 +308,7 @@ impl MusicgenDecoder { fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { let dev = input_ids.device(); - let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?; + let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?; let b_sz = b_sz_times_codebooks / self.num_codebooks; let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?; let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?; @@ -352,7 +352,7 @@ impl MusicgenForCausalLM { } pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { - let (b_sz, seq_len) = input_ids.shape().r2()?; + let (b_sz, seq_len) = input_ids.dims2()?; let hidden_states = self.decoder.forward(input_ids)?; let lm_logits = self .lm_heads |