author    | Juarez Bochi <juarez.bochi@grammarly.com> | 2023-09-12 23:14:05 -0700
committer | GitHub <noreply@github.com> | 2023-09-13 07:14:05 +0100
commit    | 9daa6dbe87a6cb11496941acb4d7d5fb785183f8 (patch)
tree      | 8824323fc15731cfaf003a5895039bad15ec2dbe /candle-examples/examples/t5
parent    | e82fcf1c594b54c105f1a3979a09f3d2e044a2e0 (diff)
Extract T5 module and add main function to use it (#829)
* Extract t5 out of musicgen
* Add main for t5 module
Diffstat (limited to 'candle-examples/examples/t5')
-rw-r--r-- | candle-examples/examples/t5/README.md | 17
-rw-r--r-- | candle-examples/examples/t5/main.rs   | 134
2 files changed, 151 insertions, 0 deletions
diff --git a/candle-examples/examples/t5/README.md b/candle-examples/examples/t5/README.md
new file mode 100644
index 00000000..66952395
--- /dev/null
+++ b/candle-examples/examples/t5/README.md
@@ -0,0 +1,17 @@
+# candle-t5
+
+Generates embeddings using a T5 model. It doesn't support generation yet.
+
+```bash
+$ cargo run --example t5 -- --model-id t5-large --prompt 'how tall is obama' --n 1
+Loaded and encoded 2.014244792s
+[[[-0.3174, -0.1462,  0.0065, ..., -0.0579, -0.0581,  0.1387],
+  [-0.2905, -0.1945, -0.0685, ..., -0.2457, -0.5137, -0.1760],
+  [-0.0591, -0.0213, -0.0241, ..., -0.0210,  0.0491, -0.0300],
+  ...
+  [-0.4333,  0.0027, -0.0609, ...,  0.3069, -0.2252,  0.3306],
+  [-0.1458,  0.1323, -0.0138, ...,  0.3000, -0.4550, -0.0384],
+  [ 0.0397,  0.0485, -0.2373, ...,  0.2578, -0.2650, -0.4356]]]
+Tensor[[1, 9, 1024], f32]
+Took 2.1363425s
+```
\ No newline at end of file
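Note (not part of this commit): the example prints per-token encoder states, shape `[1, 9, 1024]` above. A common next step is to pool them into a single fixed-size sentence embedding. A minimal sketch, assuming candle's `Tensor::mean` reduction over a dimension; `mean_pool` is a hypothetical helper, not provided by the example:

```rust
use candle::{Result, Tensor};

/// Average the encoder output of shape [batch, seq_len, hidden]
/// (e.g. the [1, 9, 1024] tensor above) over the token dimension,
/// yielding one [batch, hidden] embedding per input.
fn mean_pool(ys: &Tensor) -> Result<Tensor> {
    // Dim 1 is the sequence (token) dimension.
    ys.mean(1)
}
```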
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
new file mode 100644
index 00000000..bcba846d
--- /dev/null
+++ b/candle-examples/examples/t5/main.rs
@@ -0,0 +1,134 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use candle_transformers::models::t5;
+
+use anyhow::{anyhow, Error as E, Result};
+use candle::{DType, Tensor};
+use candle_nn::VarBuilder;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+const DTYPE: DType = DType::F32;
+const DEFAULT_PROMPT: &str = "Translate English to German: That is good.";
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Run offline (you must have the files already cached)
+    #[arg(long)]
+    offline: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    /// Compute embeddings for this prompt or use the DEFAULT_PROMPT.
+    #[arg(long)]
+    prompt: Option<String>,
+
+    /// The number of times to run the prompt.
+    #[arg(long, default_value = "1")]
+    n: usize,
+}
+
+impl Args {
+    fn build_model_and_tokenizer(&self) -> Result<(t5::T5EncoderModel, Tokenizer)> {
+        let device = candle_examples::device(self.cpu)?;
+        let default_model = "t5-small".to_string();
+        let default_revision = "refs/pr/15".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default().repo(repo);
+            (
+                cache
+                    .get("config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get("tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get("model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            (
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
+            )
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: t5::Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+        let model = t5::T5EncoderModel::load(vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    let start = std::time::Instant::now();
+
+    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+    let device = &model.device;
+    let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
+    let tokens = tokenizer
+        .encode(prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+    println!("Loaded and encoded {:?}", start.elapsed());
+    for idx in 0..args.n {
+        let start = std::time::Instant::now();
+        let ys = model.forward(&token_ids)?;
+        if idx == 0 {
+            println!("{ys}");
+        }
+        println!("Took {:?}", start.elapsed());
+    }
+    Ok(())
+}
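Note (not part of this commit): once two prompts have been pooled into single vectors (as sketched after the README above), they can be compared directly. A minimal cosine-similarity sketch using candle's elementwise ops and `sum_all`; `cosine_sim` is a hypothetical helper, not part of the example:

```rust
use candle::{Result, Tensor};

/// Cosine similarity between two 1-D f32 embedding tensors of equal length.
fn cosine_sim(a: &Tensor, b: &Tensor) -> Result<f32> {
    // Elementwise products reduced with sum_all, then the usual normalization.
    let dot = (a * b)?.sum_all()?.to_scalar::<f32>()?;
    let norm_a = (a * a)?.sum_all()?.to_scalar::<f32>()?.sqrt();
    let norm_b = (b * b)?.sum_all()?.to_scalar::<f32>()?.sqrt();
    Ok(dot / (norm_a * norm_b))
}
```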