summaryrefslogtreecommitdiff
path: root/candle-examples/examples/t5/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-21 14:48:13 +0100
committerGitHub <noreply@github.com>2023-09-21 14:48:13 +0100
commitaa8ec06fd2b02c1039a46fcb518fd6d351487978 (patch)
tree8823944439182728a108ef096804299b1654eee3 /candle-examples/examples/t5/main.rs
parentb43ca493f67a98aa6a6f53144ecb17a0a0d20fd0 (diff)
downloadcandle-aa8ec06fd2b02c1039a46fcb518fd6d351487978.tar.gz
candle-aa8ec06fd2b02c1039a46fcb518fd6d351487978.tar.bz2
candle-aa8ec06fd2b02c1039a46fcb518fd6d351487978.zip
Add the t5-xxl version. (#924)
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r--candle-examples/examples/t5/main.rs73
1 files changed, 37 insertions, 36 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 348e9a55..55929c33 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -8,12 +8,12 @@ use std::path::PathBuf;
use candle_transformers::models::t5;
-use anyhow::{anyhow, Error as E, Result};
+use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
@@ -25,10 +25,6 @@ struct Args {
#[arg(long)]
cpu: bool,
- /// Run offline (you must have the files already cached)
- #[arg(long)]
- offline: bool,
-
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@@ -80,7 +76,7 @@ struct Args {
struct T5ModelBuilder {
device: Device,
config: t5::Config,
- weights_filename: PathBuf,
+ weights_filename: Vec<PathBuf>,
}
impl T5ModelBuilder {
@@ -95,28 +91,21 @@ impl T5ModelBuilder {
(None, None) => (default_model, default_revision),
};
- let repo = Repo::with_revision(model_id, RepoType::Model, revision);
- let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
- let cache = Cache::default().repo(repo);
- (
- cache
- .get("config.json")
- .ok_or(anyhow!("Missing config file in cache"))?,
- cache
- .get("tokenizer.json")
- .ok_or(anyhow!("Missing tokenizer file in cache"))?,
- cache
- .get("model.safetensors")
- .ok_or(anyhow!("Missing weights file in cache"))?,
- )
+ let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
+ let api = Api::new()?;
+ let api = api.repo(repo);
+ let config_filename = api.get("config.json")?;
+ let tokenizer_filename = api.get("tokenizer.json")?;
+ let weights_filename = if model_id == "google/flan-t5-xxl" {
+ vec![
+ api.get("model-00001-of-00005.safetensors")?,
+ api.get("model-00002-of-00005.safetensors")?,
+ api.get("model-00003-of-00005.safetensors")?,
+ api.get("model-00004-of-00005.safetensors")?,
+ api.get("model-00005-of-00005.safetensors")?,
+ ]
} else {
- let api = Api::new()?;
- let api = api.repo(repo);
- (
- api.get("config.json")?,
- api.get("tokenizer.json")?,
- api.get("model.safetensors")?,
- )
+ vec![api.get("model.safetensors")?]
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;
@@ -133,18 +122,30 @@ impl T5ModelBuilder {
}
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
- let weights =
- unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+ let weights = self
+ .weights_filename
+ .iter()
+ .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
+ .collect::<candle::Result<Vec<_>>>()?;
+ let weights = weights
+ .iter()
+ .map(|w| w.deserialize())
+ .collect::<candle::Result<Vec<_>>>()?;
+ let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
- let weights =
- unsafe { candle::safetensors::MmapedFile::new(self.weights_filename.clone())? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &self.device);
+ let weights = self
+ .weights_filename
+ .iter()
+ .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
+ .collect::<candle::Result<Vec<_>>>()?;
+ let weights = weights
+ .iter()
+ .map(|w| w.deserialize())
+ .collect::<candle::Result<Vec<_>>>()?;
+ let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}