 candle-core/src/cpu_backend.rs                      |  2
 candle-core/src/cuda_backend.rs                     |  2
 candle-core/src/shape.rs                            | 18
 candle-core/src/tensor.rs                           | 12
 candle-core/tests/tensor_tests.rs                   | 10
 candle-examples/examples/bert/main.rs               |  2
 candle-examples/examples/bert/model.rs              |  4
 candle-examples/examples/falcon/model.rs            |  8
 candle-examples/examples/llama/model.rs             | 12
 candle-examples/examples/musicgen/musicgen_model.rs |  8
 candle-examples/examples/musicgen/t5_model.rs       |  2
 candle-examples/examples/simple-training/main.rs    |  2
 candle-examples/examples/whisper/main.rs            |  4
 candle-examples/examples/whisper/model.rs           |  6
 candle-nn/src/conv.rs                               |  2
 candle-nn/src/layer_norm.rs                         |  2
 candle-wasm-example/src/model.rs                    |  6
 candle-wasm-example/src/worker.rs                   |  4
 18 files changed, 56 insertions(+), 50 deletions(-)
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index b8d52c95..82e1f3e2 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1688,7 +1688,7 @@ impl BackendStorage for CpuStorage {
 
     fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
+        let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
         Embedding {
             vocab_size,
             hidden_size,
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 43bfef2d..f9fefe17 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -620,7 +620,7 @@ impl<'a> Map1 for Embedding<'a> {
         let shape = ids_l.shape();
         let (v_size, h_size) = rhs_l
             .shape()
-            .r2()
+            .dims2()
             .map_err(|e| CudaError::WrappedError(Box::new(e)))
             .w()?;
         let dims = shape.dims();
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 982f9db0..b016ead5 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -87,6 +87,12 @@ macro_rules! extract_dims {
                 }
             }
         }
+        impl crate::Tensor {
+            pub fn $fn_name(&self) -> Result<$out_type> {
+                self.shape().$fn_name()
+            }
+        }
+
         impl std::convert::TryInto<$out_type> for Shape {
             type Error = crate::Error;
             fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
@@ -328,23 +334,23 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
     }
 }
 
-extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
-extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
-extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
+extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
+extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
+extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
 extract_dims!(
-    r3,
+    dims3,
     3,
     |d: &[usize]| (d[0], d[1], d[2]),
     (usize, usize, usize)
 );
 extract_dims!(
-    r4,
+    dims4,
     4,
     |d: &[usize]| (d[0], d[1], d[2], d[3]),
     (usize, usize, usize, usize)
 );
 extract_dims!(
-    r5,
+    dims5,
     5,
     |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
     (usize, usize, usize, usize, usize)
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 8ba0ba43..561f1863 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -772,7 +772,7 @@ impl Tensor {
 
     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
-        let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
+        let (c_out, c_in_k, k_size) = kernel.dims3()?;
         let (b_size, c_in, l_in) = match *self.dims() {
             [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
             [c_in, l_in] => (None, c_in, l_in),
@@ -931,8 +931,8 @@ impl Tensor {
                 .bt())?
         }
         let ids_shape = ids.shape();
-        let seq_len = ids_shape.r1()?;
-        let (_, hidden_size) = rhs.shape().r2()?;
+        let seq_len = ids_shape.dims1()?;
+        let (_, hidden_size) = rhs.dims2()?;
         let storage = ids
             .storage()
             .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
@@ -1013,7 +1013,7 @@ impl Tensor {
             // The number of element in indexes must match the dimension on which the add is
             // performed on the source tensor (and the index values from `indexes` are taken from
             // the target tensor self)
-            mismatch || source_dims[dim] != indexes.shape().r1()?
+            mismatch || source_dims[dim] != indexes.dims1()?
         };
         if mismatch {
             Err(Error::ShapeMismatchBinaryOp {
@@ -1144,7 +1144,7 @@ impl Tensor {
 
     /// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
     pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
-        let (dim1, dim2) = self.shape().r2()?;
+        let (dim1, dim2) = self.dims2()?;
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut rows = vec![];
@@ -1164,7 +1164,7 @@ impl Tensor {
 
     /// Returns the data contained in a 3D tensor.
     pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
-        let (dim1, dim2, dim3) = self.shape().r3()?;
+        let (dim1, dim2, dim3) = self.dims3()?;
         let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             let mut top_rows = vec![];
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 6415fcb3..a126d634 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -4,7 +4,7 @@ use test_utils::to_vec3_round;
 
 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
-    let (dim1, dim2) = tensor.shape().r2()?;
+    let (dim1, dim2) = tensor.dims2()?;
    assert_eq!(dim1, 5);
    assert_eq!(dim2, 2);
    Ok(())
@@ -12,7 +12,7 @@ fn zeros(device: &Device) -> Result<()> {
 
 fn add_mul(device: &Device) -> Result<()> {
     let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
-    let dim1 = tensor.shape().r1()?;
+    let dim1 = tensor.dims1()?;
     assert_eq!(dim1, 3);
     let content: Vec<f32> = tensor.to_vec1()?;
     assert_eq!(content, [3., 1., 4.]);
@@ -28,7 +28,7 @@ fn add_mul(device: &Device) -> Result<()> {
 fn tensor_2d(device: &Device) -> Result<()> {
     let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
     let tensor = Tensor::new(data, device)?;
-    let dims = tensor.shape().r2()?;
+    let dims = tensor.dims2()?;
     assert_eq!(dims, (2, 5));
     let content: Vec<Vec<f32>> = tensor.to_vec2()?;
     assert_eq!(content, data);
@@ -41,7 +41,7 @@ fn binary_op(device: &Device) -> Result<()> {
     let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
     let tensor2 = Tensor::new(data2, device)?;
     let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
-    let dims = tensor.shape().r2()?;
+    let dims = tensor.dims2()?;
     assert_eq!(dims, (2, 5));
     let content: Vec<Vec<f32>> = tensor.to_vec2()?;
     assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
@@ -56,7 +56,7 @@ fn binary_op(device: &Device) -> Result<()> {
 fn transpose(device: &Device) -> Result<()> {
     let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
     let tensor = Tensor::new(data, device)?.t()?;
-    let dims = tensor.shape().r2()?;
+    let dims = tensor.dims2()?;
     assert_eq!(dims, (5, 2));
     assert_eq!(
         tensor.to_vec2::<f32>()?,
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 33f0a1fe..6672ad09 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -161,7 +161,7 @@ fn main() -> Result<()> {
         let embeddings = model.forward(&token_ids, &token_type_ids)?;
         println!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
         let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
         println!("pooled embeddings {:?}", embeddings.shape());
         let mut similarities = vec![];
diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs
index fa0e8c76..3bf412b2 100644
--- a/candle-examples/examples/bert/model.rs
+++ b/candle-examples/examples/bert/model.rs
@@ -87,7 +87,7 @@ impl LayerNorm {
             DType::F16 | DType::BF16 => DType::F32,
             d => d,
         };
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
         let x = x.to_dtype(internal_dtype)?;
         let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
@@ -262,7 +262,7 @@ impl BertEmbeddings {
 
     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let (_bsize, seq_len) = input_ids.shape().r2()?;
+        let (_bsize, seq_len) = input_ids.dims2()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
         let mut embeddings = (&input_embeddings + token_type_embeddings)?;
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index 60821add..bce93c81 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -182,7 +182,7 @@ impl FalconRotaryEmbedding {
         key: &Tensor,
         past_kv_len: usize,
     ) -> Result<(Tensor, Tensor)> {
-        let (_batch, seq_len, _head_dim) = query.shape().r3()?;
+        let (_batch, seq_len, _head_dim) = query.dims3()?;
         let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
         let cos = cos.narrow(0, past_kv_len, seq_len)?;
         let sin = sin.narrow(0, past_kv_len, seq_len)?;
@@ -245,7 +245,7 @@ impl FalconAttention {
     }
 
     fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
-        let (b_sz, seq_len, _) = fused_qkv.shape().r3()?;
+        let (b_sz, seq_len, _) = fused_qkv.dims3()?;
         if !self.multi_query {
             let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
             let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
@@ -267,7 +267,7 @@ impl FalconAttention {
         let fused_qkv = self.query_key_value.forward(x)?;
         let head_dim = self.head_dim;
         let (query, key, value) = self.split_heads(&fused_qkv)?;
-        let (b_sz, seq_len, _, _) = query.shape().r4()?;
+        let (b_sz, seq_len, _, _) = query.dims4()?;
         let query = query
             .transpose(1, 2)?
             .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
@@ -465,7 +465,7 @@ impl Falcon {
     }
 
     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len) = input_ids.shape().r2()?;
+        let (b_sz, seq_len) = input_ids.dims2()?;
         let mut hidden_state = self.word_embeddings.forward(input_ids)?;
         let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
             Some((k, _)) => k.dim(1)?,
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f3e30ec9..b074e5cb 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -116,11 +116,11 @@ impl RmsNorm {
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let (b_sz, seq_len, hidden_size) = x.dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
-        let size = self.scale.shape().r1()?;
+        let size = self.scale.dims1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
@@ -144,7 +144,7 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
+        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
         let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let x_dtype = x.dtype();
-        let (b_sz, seq_len, n_embd) = x.shape().r3()?;
+        let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
         let k = self.k_proj.forward(x)?;
         let v = self.v_proj.forward(x)?;
@@ -219,7 +219,7 @@ impl CausalSelfAttention {
         if n_rep == 1 {
             Ok(x)
         } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
             let x = x
                 .unsqueeze(2)?
                 .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
@@ -345,7 +345,7 @@ impl Llama {
     }
 
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, seq_len) = x.shape().r2()?;
+        let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, index_pos, block_idx)?;
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 3c5e66f8..212f6818 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
     }
 
     fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
+        let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
         if seq_len > self.weights.dim(0)? {
             self.weights = get_embedding(seq_len, self.embedding_dim)?
         }
@@ -170,7 +170,7 @@ impl MusicgenAttention {
         kv_states: Option<&Tensor>,
         attention_mask: &Tensor,
     ) -> Result<Tensor> {
-        let (b_sz, tgt_len, _) = xs.shape().r3()?;
+        let (b_sz, tgt_len, _) = xs.dims3()?;
         let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
 
         let kv_states = kv_states.unwrap_or(xs);
@@ -308,7 +308,7 @@ impl MusicgenDecoder {
 
     fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
         let dev = input_ids.device();
-        let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
+        let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
         let b_sz = b_sz_times_codebooks / self.num_codebooks;
         let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
         let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
@@ -352,7 +352,7 @@ impl MusicgenForCausalLM {
     }
 
     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len) = input_ids.shape().r2()?;
+        let (b_sz, seq_len) = input_ids.dims2()?;
         let hidden_states = self.decoder.forward(input_ids)?;
         let lm_logits = self
             .lm_heads
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 15945d4e..61c0a1bb 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -338,7 +338,7 @@ impl T5Stack {
 
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
+        let (_b_sz, _seq_len) = input_embeds.dims2()?;
         let mut hidden_states = self.dropout.forward(&input_embeds)?;
 
         for block in self.block.iter() {
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index 2cfe4923..60f2281b 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -52,7 +52,7 @@ pub fn main() -> Result<()> {
             .to_dtype(DType::F32)?
             .sum_all()?
             .to_scalar::<f32>()?;
-        let test_accuracy = sum_ok / test_labels.shape().r1()? as f32;
+        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
         println!(
             "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
             loss.to_scalar::<f32>()?,
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index d7b303cf..079424e3 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -127,7 +127,7 @@ impl Decoder {
                 .to_scalar::<f32>()? as f64;
         }
 
-        let (seq_len, _) = logits.shape().r2()?;
+        let (seq_len, _) = logits.dims2()?;
         let logits = logits
             .get(seq_len - 1)?
             .broadcast_add(&self.suppress_tokens)?;
@@ -195,7 +195,7 @@ impl Decoder {
     }
 
     fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
-        let (_, _, content_frames) = mel.shape().r3()?;
+        let (_, _, content_frames) = mel.dims3()?;
         let mut seek = 0;
         let mut segments = vec![];
         while seek < content_frames {
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index d4553e79..330b2a00 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -132,7 +132,7 @@ impl MultiHeadAttention {
     }
 
     fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
-        let (n_batch, n_ctx, n_state) = x.shape().r3()?;
+        let (n_batch, n_ctx, n_state) = x.dims3()?;
         let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
         Ok(x.reshape(target_dims)?.transpose(1, 2)?)
     }
@@ -144,7 +144,7 @@ impl MultiHeadAttention {
         v: &Tensor,
         mask: Option<&Tensor>,
     ) -> Result<Tensor> {
-        let (_, n_ctx, n_state) = q.shape().r3()?;
+        let (_, n_ctx, n_state) = q.dims3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
         let q = (self.reshape_head(q)? * scale)?;
         let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
@@ -270,7 +270,7 @@ impl AudioEncoder {
         let x = self.conv1.forward(x)?.gelu()?;
         let x = self.conv2.forward(&x)?.gelu()?;
         let x = x.transpose(1, 2)?;
-        let (_bsize, seq_len, _hidden) = x.shape().r3()?;
+        let (_bsize, seq_len, _hidden) = x.dims3()?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
         let mut x = x.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index d938cae4..8fbe7659 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -41,7 +41,7 @@ impl Conv1d {
         match &self.bias {
             None => Ok(x),
             Some(bias) => {
-                let b = bias.shape().r1()?;
+                let b = bias.dims1()?;
                 let bias = bias.reshape((1, b, 1))?;
                 Ok(x.broadcast_add(&bias)?)
             }
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 88d5ab32..8f8544bb 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -49,7 +49,7 @@ impl LayerNorm {
             DType::F16 | DType::BF16 => DType::F32,
             d => d,
         };
-        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
        let x = x.to_dtype(internal_dtype)?;
        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
        let x = x.broadcast_sub(&mean_x)?;
diff --git a/candle-wasm-example/src/model.rs b/candle-wasm-example/src/model.rs
index 89c0d708..97eff839 100644
--- a/candle-wasm-example/src/model.rs
+++ b/candle-wasm-example/src/model.rs
@@ -164,7 +164,7 @@ impl MultiHeadAttention {
     }
 
     fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
-        let (n_batch, n_ctx, n_state) = x.shape().r3()?;
+        let (n_batch, n_ctx, n_state) = x.dims3()?;
         let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
         Ok(x.reshape(target_dims)?.transpose(1, 2)?)
     }
@@ -176,7 +176,7 @@ impl MultiHeadAttention {
         v: &Tensor,
         mask: Option<&Tensor>,
     ) -> Result<Tensor> {
-        let (_, n_ctx, n_state) = q.shape().r3()?;
+        let (_, n_ctx, n_state) = q.dims3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
         let q = {
             let _timer = crate::Timer::new("q::reshape");
@@ -328,7 +328,7 @@ impl AudioEncoder {
             self.conv2.forward(&x)?.gelu()?
         };
         let x = x.transpose(1, 2)?;
-        let (_bsize, seq_len, _hidden) = x.shape().r3()?;
+        let (_bsize, seq_len, _hidden) = x.dims3()?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
         let mut x = x.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
diff --git a/candle-wasm-example/src/worker.rs b/candle-wasm-example/src/worker.rs
index 5001e7e4..ea64bf02 100644
--- a/candle-wasm-example/src/worker.rs
+++ b/candle-wasm-example/src/worker.rs
@@ -134,7 +134,7 @@ impl Decoder {
                 .to_scalar::<f32>()? as f64;
         }
 
-        let (seq_len, _) = logits.shape().r2()?;
+        let (seq_len, _) = logits.dims2()?;
         let logits = logits
             .get(seq_len - 1)?
             .broadcast_add(&self.suppress_tokens)?;
@@ -207,7 +207,7 @@ impl Decoder {
 
     fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
         let mut rng = StdRng::seed_from_u64(299792458);
-        let (_, _, content_frames) = mel.shape().r3()?;
+        let (_, _, content_frames) = mel.dims3()?;
         let mut seek = 0;
         let mut segments = vec![];
         while seek < content_frames {
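For reference, a minimal usage sketch of the renamed accessors. The patch renames the rank-checked shape extractors r0..r5 to dims0..dims5 on Shape, and the new impl crate::Tensor block in the extract_dims! macro forwards each of them from Tensor, so t.dims2()? becomes shorthand for t.shape().dims2()?. The import path below is an assumption (the core package lives at candle-core; adjust the use line if your Cargo.toml aliases it, e.g. as candle):

    // Sketch of the post-patch API; mirrors the zeros/add_mul tests above.
    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;

        // Rank-2 tensor: dims2() yields (usize, usize) and returns an error
        // (rather than panicking) if the tensor is not rank 2.
        let t = Tensor::zeros((5, 2), DType::F32, &device)?;
        let (dim1, dim2) = t.dims2()?; // was: t.shape().r2()?
        assert_eq!((dim1, dim2), (5, 2));

        // Rank-1 variant returns a bare usize.
        let v = Tensor::new(&[3f32, 1., 4.], &device)?;
        assert_eq!(v.dims1()?, 3); // was: v.shape().r1()?

        // The Shape-level methods remain; the Tensor methods just forward to them.
        assert_eq!(t.shape().dims2()?, (5, 2));
        Ok(())
    }

The forwarding methods remove the .shape() hop at nearly every call site in the diff, and unlike indexing t.dims() directly, a rank mismatch surfaces as a Result at the call site instead of an out-of-bounds panic.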