author    Juarez Bochi <juarez.bochi@grammarly.com>  2023-09-12 23:14:05 -0700
committer GitHub <noreply@github.com>  2023-09-13 07:14:05 +0100
commit    9daa6dbe87a6cb11496941acb4d7d5fb785183f8 (patch)
tree      8824323fc15731cfaf003a5895039bad15ec2dbe /candle-examples/examples/t5
parent    e82fcf1c594b54c105f1a3979a09f3d2e044a2e0 (diff)
Extract T5 module and add main function to use it (#829)
* Extract t5 out of musicgen
* Add main for t5 module
Diffstat (limited to 'candle-examples/examples/t5')
-rw-r--r--  candle-examples/examples/t5/README.md   17
-rw-r--r--  candle-examples/examples/t5/main.rs     134
2 files changed, 151 insertions, 0 deletions
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md
new file mode 100644
index 00000000..66952395
--- /dev/null
+++ b/candle-examples/examples/t5/README.md
@@ -0,0 +1,17 @@
+# candle-t5
+
+Generates embeddings using a T5 encoder model. Text generation is not supported yet.
+
+```bash
+$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
+Loaded and encoded 2.014244792s
+[[[-0.3174, -0.1462, 0.0065, ..., -0.0579, -0.0581, 0.1387],
+ [-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
+ [-0.0591, -0.0213, -0.0241, ..., -0.0210, 0.0491, -0.0300],
+ ...
+ [-0.4333, 0.0027, -0.0609, ..., 0.3069, -0.2252, 0.3306],
+ [-0.1458, 0.1323, -0.0138, ..., 0.3000, -0.4550, -0.0384],
+ [ 0.0397, 0.0485, -0.2373, ..., 0.2578, -0.2650, -0.4356]]]
+Tensor[[1, 9, 1024], f32]
+Took 2.1363425s
+```
\ No newline at end of file
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
new file mode 100644
index 00000000..bcba846d
--- /dev/null
+++ b/candle-examples/examples/t5/main.rs
@@ -0,0 +1,134 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle_transformers::models::t5;
+
+use anyhow::{anyhow, Error as E, Result};
+use candle::{DType, Tensor};
+use candle_nn::VarBuilder;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+const DTYPE: DType = DType::F32;
+const DEFAULT_PROMPT: &str = "Translate English to German: That is good.";
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Run offline (you must have the files already cached)
+    #[arg(long)]
+    offline: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model to use; see the available models at https://huggingface.co/models?search=t5
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// Compute embeddings for this prompt or use the DEFAULT_PROMPT.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+}
+
+impl Args {
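+    /// Resolve the config, tokenizer, and weight files, then load the T5 encoder.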
+    fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let default_model = "t5-small".to_string();
+        let default_revision = "refs/pr/15".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
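+        // Offline mode resolves files from the local Hugging Face cache;
+        // otherwise they are fetched through the hub API.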
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default().repo(repo);
+            (
+                cache
+                    .get("config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get("tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get("model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            (
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
+            )
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: t5::Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
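+        // Memory-mapping the weights is unsafe: the file must not be modified
+        // while it is mapped.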
+        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+        let model = t5::T5EncoderModel::load(vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
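+    // Keep the guard alive until exit so the chrome trace is flushed on drop.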
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    let start = std::time::Instant::now();
+
+    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let device = &model.device;
+    let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
+    let tokens = tokenizer
+        .encode(prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
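+    // Add a batch dimension: the encoder expects token ids of shape [batch, seq_len].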
+    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    println!("Loaded and encoded {:?}", start.elapsed());
+    for idx in 0..args.n {
+        let start = std::time::Instant::now();
+        let ys = model.forward(&token_ids)?;
+        if idx == 0 {
+            println!("{ys}");
+        }
+        println!("Took {:?}", start.elapsed());
+    }
+    Ok(())
+}
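
Beyond this commit: the example prints the raw encoder output of shape `[1, seq_len, d_model]`, and a common next step is to mean-pool over the token dimension to get one fixed-size vector per prompt. A minimal sketch, assuming the same candle `Tensor` API used in `main.rs` above (the `mean_pool` helper is hypothetical, not part of this commit):

```rust
use candle::{Result, Tensor};

// Hypothetical helper (not part of this commit): average the encoder output
// over the token dimension to get a single embedding vector.
fn mean_pool(ys: &Tensor) -> Result<Tensor> {
    let (_batch, seq_len, _dim) = ys.dims3()?;
    // Sum over dim 1 (tokens), then divide by the token count: [1, d_model].
    let pooled = (ys.sum(1)? / (seq_len as f64))?;
    pooled.squeeze(0) // -> [d_model]
}
```

With the README's sample output, this would reduce the `[1, 9, 1024]` tensor to a single 1024-dimensional embedding.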