summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/whisper
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/whisper')
-rw-r--r--  candle-examples/examples/whisper/main.rs  | 4
-rw-r--r--  candle-examples/examples/whisper/model.rs | 6
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index d7b303cf..079424e3 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -127,7 +127,7 @@ impl Decoder {
.to_scalar::<f32>()? as f64;
}
- let (seq_len, _) = logits.shape().r2()?;
+ let (seq_len, _) = logits.dims2()?;
let logits = logits
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
@@ -195,7 +195,7 @@ impl Decoder {
}
fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
- let (_, _, content_frames) = mel.shape().r3()?;
+ let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
let mut segments = vec![];
while seek < content_frames {
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index d4553e79..330b2a00 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -132,7 +132,7 @@ impl MultiHeadAttention {
}
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
- let (n_batch, n_ctx, n_state) = x.shape().r3()?;
+ let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
}
@@ -144,7 +144,7 @@ impl MultiHeadAttention {
v: &Tensor,
mask: Option<&Tensor>,
) -> Result<Tensor> {
- let (_, n_ctx, n_state) = q.shape().r3()?;
+ let (_, n_ctx, n_state) = q.dims3()?;
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
let q = (self.reshape_head(q)? * scale)?;
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
@@ -270,7 +270,7 @@ impl AudioEncoder {
let x = self.conv1.forward(x)?.gelu()?;
let x = self.conv2.forward(&x)?.gelu()?;
let x = x.transpose(1, 2)?;
- let (_bsize, seq_len, _hidden) = x.shape().r3()?;
+ let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
let mut x = x.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {