Diffstat (limited to 'candle-examples/examples/marian-mt/main.rs')
-rw-r--r-- | candle-examples/examples/marian-mt/main.rs | 90
1 file changed, 90 insertions, 0 deletions
diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs
new file mode 100644
index 00000000..ed044627
--- /dev/null
+++ b/candle-examples/examples/marian-mt/main.rs
@@ -0,0 +1,90 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Error as E;
+use clap::Parser;
+
+use candle::{DType, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::models::marian;
+
+use tokenizers::Tokenizer;
+
+// TODO: Maybe add support for the conditional prompt.
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: String,
+
+    #[arg(long)]
+    tokenizer: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Use the quantized version of the model.
+    #[arg(long)]
+    quantized: bool,
+
+    /// Text to be translated
+    #[arg(long)]
+    text: String,
+}
+
+const SEP_TOKEN_ID: u32 = 102;
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let config = marian::Config::opus_mt_tc_big_fr_en();
+
+    let device = candle_examples::device(args.cpu)?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
+    let model = marian::MTModel::new(&config, vb)?;
+
+    let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?;
+    let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
+    let mut logits_processor =
+        candle_transformers::generation::LogitsProcessor::new(1337, None, None);
+
+    let encoder_xs = {
+        let tokens = tokenizer
+            .encode(args.text, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+        model.encoder().forward(&tokens, 0)?
+    };
+
+    let mut token_ids = vec![30522u32];
+    for index in 0..1000 {
+        // TODO: Add a kv cache.
+        let context_size = if index >= 1000 { 1 } else { token_ids.len() };
+        let start_pos = token_ids.len().saturating_sub(context_size);
+        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
+        let logits = model.decode(&input_ids, &encoder_xs)?;
+        let logits = logits.squeeze(0)?;
+        let logits = logits.get(logits.dim(0)? - 1)?;
+        let token = logits_processor.sample(&logits)?;
+        if token == SEP_TOKEN_ID {
+            break;
+        }
+        token_ids.push(token);
+        if let Some(t) = tokenizer_dec.next_token(token)? {
+            use std::io::Write;
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+    }
+    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+
+    Ok(())
+}
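
Note: the example decodes greedily, since LogitsProcessor::new(1337, None, None) is built with temperature and top-p both set to None, so sample() reduces to an argmax over the last-position logits. The --model flag expects a local safetensors file; a minimal sketch of fetching one with the hf-hub crate, under the assumption (not part of this diff) that the Helsinki-NLP/opus-mt-tc-big-fr-en hub repo ships a model.safetensors, could look like:

    use anyhow::Result;
    use hf_hub::api::sync::Api;
    use std::path::PathBuf;

    // Sketch only: download (or reuse from the local cache) the Marian fr-en
    // weights and return the path, suitable for passing as --model.
    // `fetch_weights` and the `model.safetensors` filename are assumptions.
    fn fetch_weights() -> Result<PathBuf> {
        let api = Api::new()?;
        let repo = api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string());
        Ok(repo.get("model.safetensors")?)
    }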
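
With local weight and tokenizer files in place, a typical invocation (paths and the French input text are illustrative; the tokenizer must match the Marian vocabulary) would be:

    cargo run --example marian-mt --release -- \
        --model model.safetensors \
        --tokenizer tokenizer.json \
        --text "Demain, dès l'aube, je partirai."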