use candle::{IndexOp, Result, Tensor, D}; use tokenizers::Tokenizer; const LANGUAGES: [(&str, &str); 99] = [ ("en", "english"), ("zh", "chinese"), ("de", "german"), ("es", "spanish"), ("ru", "russian"), ("ko", "korean"), ("fr", "french"), ("ja", "japanese"), ("pt", "portuguese"), ("tr", "turkish"), ("pl", "polish"), ("ca", "catalan"), ("nl", "dutch"), ("ar", "arabic"), ("sv", "swedish"), ("it", "italian"), ("id", "indonesian"), ("hi", "hindi"), ("fi", "finnish"), ("vi", "vietnamese"), ("he", "hebrew"), ("uk", "ukrainian"), ("el", "greek"), ("ms", "malay"), ("cs", "czech"), ("ro", "romanian"), ("da", "danish"), ("hu", "hungarian"), ("ta", "tamil"), ("no", "norwegian"), ("th", "thai"), ("ur", "urdu"), ("hr", "croatian"), ("bg", "bulgarian"), ("lt", "lithuanian"), ("la", "latin"), ("mi", "maori"), ("ml", "malayalam"), ("cy", "welsh"), ("sk", "slovak"), ("te", "telugu"), ("fa", "persian"), ("lv", "latvian"), ("bn", "bengali"), ("sr", "serbian"), ("az", "azerbaijani"), ("sl", "slovenian"), ("kn", "kannada"), ("et", "estonian"), ("mk", "macedonian"), ("br", "breton"), ("eu", "basque"), ("is", "icelandic"), ("hy", "armenian"), ("ne", "nepali"), ("mn", "mongolian"), ("bs", "bosnian"), ("kk", "kazakh"), ("sq", "albanian"), ("sw", "swahili"), ("gl", "galician"), ("mr", "marathi"), ("pa", "punjabi"), ("si", "sinhala"), ("km", "khmer"), ("sn", "shona"), ("yo", "yoruba"), ("so", "somali"), ("af", "afrikaans"), ("oc", "occitan"), ("ka", "georgian"), ("be", "belarusian"), ("tg", "tajik"), ("sd", "sindhi"), ("gu", "gujarati"), ("am", "amharic"), ("yi", "yiddish"), ("lo", "lao"), ("uz", "uzbek"), ("fo", "faroese"), ("ht", "haitian creole"), ("ps", "pashto"), ("tk", "turkmen"), ("nn", "nynorsk"), ("mt", "maltese"), ("sa", "sanskrit"), ("lb", "luxembourgish"), ("my", "myanmar"), ("bo", "tibetan"), ("tl", "tagalog"), ("mg", "malagasy"), ("as", "assamese"), ("tt", "tatar"), ("haw", "hawaiian"), ("ln", "lingala"), ("ha", "hausa"), ("ba", "bashkir"), ("jw", "javanese"), ("su", "sundanese"), ]; /// Returns the token id for the selected language. pub fn detect_language( model: &mut super::Model, tokenizer: &Tokenizer, mel: &Tensor, ) -> Result { let (_bsize, _, seq_len) = mel.dims3()?; let mel = mel.narrow( 2, 0, usize::min(seq_len, model.config().max_source_positions), )?; let device = mel.device(); let language_token_ids = LANGUAGES .iter() .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>"))) .collect::>>()?; let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?; let audio_features = model.encoder_forward(&mel, true)?; let tokens = Tensor::new(&[[sot_token]], device)?; let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; let ys = model.decoder_forward(&tokens, &audio_features, true)?; let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?; let logits = logits.index_select(&language_token_ids, 0)?; let probs = candle_nn::ops::softmax(&logits, D::Minus1)?; let probs = probs.to_vec1::()?; let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::>(); probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); for ((_, language), p) in probs.iter().take(5) { println!("{language}: {p}") } let language = crate::token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?; Ok(language) }