Diffstat (limited to 'candle-examples/examples/marian-mt/main.rs')
-rw-r--r--	candle-examples/examples/marian-mt/main.rs	90
1 file changed, 90 insertions(+), 0 deletions(-)
diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs
new file mode 100644
index 00000000..ed044627
--- /dev/null
+++ b/candle-examples/examples/marian-mt/main.rs
@@ -0,0 +1,90 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Error as E;
+use clap::Parser;
+
+use candle::{DType, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::models::marian;
+
+use tokenizers::Tokenizer;
+
+// TODO: Maybe add support for the conditional prompt.
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: String,
+
+    #[arg(long)]
+    tokenizer: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Use the quantized version of the model (currently unused in this example).
+    #[arg(long)]
+    quantized: bool,
+
+    /// Text to be translated.
+    #[arg(long)]
+    text: String,
+}
+
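+// Generation stops as soon as this token id is produced; the value is hard-coded for the tokenizer this example expects.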
+const SEP_TOKEN_ID: u32 = 102;
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
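+    // The configuration is hard-coded to the "opus-mt-tc-big-fr-en" (French to English) checkpoint.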
+    let config = marian::Config::opus_mt_tc_big_fr_en();
+
+    let device = candle_examples::device(args.cpu)?;
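+    // Memory-map the safetensors weights instead of reading them eagerly; the call is unsafe
+    // because the mapped file must not be modified while the model is in use.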
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
+    let model = marian::MTModel::new(&config, vb)?;
+
+    let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?;
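+    // TokenOutputStream turns generated token ids back into text incrementally,
+    // holding back partial pieces until they form printable output.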
+    let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
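+    // With no temperature and no top-p configured, the logits processor falls back to greedy (argmax) sampling.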
+    let mut logits_processor =
+        candle_transformers::generation::LogitsProcessor::new(1337, None, None);
+
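+    // Encode the source text once; the decoder cross-attends to these hidden states at every generation step.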
+    let encoder_xs = {
+        let tokens = tokenizer
+            .encode(args.text, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+        model.encoder().forward(&tokens, 0)?
+    };
+
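+    // Seed the decoder with a hard-coded start-of-generation token id matching the tokenizer/model pair this example expects.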
+    let mut token_ids = vec![30522u32];
+    for index in 0..1000 {
+        // TODO: Add a kv cache.
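+        // Without a kv cache the full decoded sequence is fed back to the decoder on every step,
+        // so the branch below never actually shrinks the context inside this loop.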
+        let context_size = if index >= 1000 { 1 } else { token_ids.len() };
+        let start_pos = token_ids.len().saturating_sub(context_size);
+        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
+        let logits = model.decode(&input_ids, &encoder_xs)?;
+        let logits = logits.squeeze(0)?;
+        let logits = logits.get(logits.dim(0)? - 1)?;
+        let token = logits_processor.sample(&logits)?;
+        if token == SEP_TOKEN_ID {
+            break;
+        }
+        token_ids.push(token);
+        if let Some(t) = tokenizer_dec.next_token(token)? {
+            use std::io::Write;
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+    }
+    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+
+    Ok(())
+}
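
For reference, a hypothetical invocation of this example; the weight and tokenizer file names below are placeholders for files obtained separately, and since the configuration is hard-coded to the French-to-English model the input text should be French:

    cargo run --example marian-mt --release -- \
        --model marian-fr-en.safetensors \
        --tokenizer tokenizer-marian-fr-en.json \
        --text "Demain, dès l'aube, je partirai."

The translated English text is streamed to stdout as tokens are generated.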