diff options
Diffstat (limited to 'candle-examples/examples/whisper')
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 4 | ||||
-rw-r--r-- | candle-examples/examples/whisper/model.rs | 6 |
2 files changed, 5 insertions, 5 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index d7b303cf..079424e3 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -127,7 +127,7 @@ impl Decoder { .to_scalar::<f32>()? as f64; } - let (seq_len, _) = logits.shape().r2()?; + let (seq_len, _) = logits.dims2()?; let logits = logits .get(seq_len - 1)? .broadcast_add(&self.suppress_tokens)?; @@ -195,7 +195,7 @@ impl Decoder { } fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> { - let (_, _, content_frames) = mel.shape().r3()?; + let (_, _, content_frames) = mel.dims3()?; let mut seek = 0; let mut segments = vec![]; while seek < content_frames { diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index d4553e79..330b2a00 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -132,7 +132,7 @@ impl MultiHeadAttention { } fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { - let (n_batch, n_ctx, n_state) = x.shape().r3()?; + let (n_batch, n_ctx, n_state) = x.dims3()?; let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; Ok(x.reshape(target_dims)?.transpose(1, 2)?) } @@ -144,7 +144,7 @@ impl MultiHeadAttention { v: &Tensor, mask: Option<&Tensor>, ) -> Result<Tensor> { - let (_, n_ctx, n_state) = q.shape().r3()?; + let (_, n_ctx, n_state) = q.dims3()?; let scale = ((n_state / self.n_head) as f64).powf(-0.25); let q = (self.reshape_head(q)? * scale)?; let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; @@ -270,7 +270,7 @@ impl AudioEncoder { let x = self.conv1.forward(x)?.gelu()?; let x = self.conv2.forward(&x)?.gelu()?; let x = x.transpose(1, 2)?; - let (_bsize, seq_len, _hidden) = x.shape().r3()?; + let (_bsize, seq_len, _hidden) = x.dims3()?; let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; let mut x = x.broadcast_add(&positional_embedding)?; for block in self.blocks.iter() { |