summaryrefslogtreecommitdiff
path: root/candle-examples/examples/t5/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-23 20:39:52 +0100
committerGitHub <noreply@github.com>2023-09-23 20:39:52 +0100
commit890d069092a3158838b82f3d8fbdf709c84e8770 (patch)
tree4fbdf8b246b20bdacc0386aea901a2d985667d55 /candle-examples/examples/t5/main.rs
parent5dbe46b389da4ba39131ce34752249cae640ad9e (diff)
downloadcandle-890d069092a3158838b82f3d8fbdf709c84e8770.tar.gz
candle-890d069092a3158838b82f3d8fbdf709c84e8770.tar.bz2
candle-890d069092a3158838b82f3d8fbdf709c84e8770.zip
Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers. * Use the new safetensor container in varbuilders.
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r--candle-examples/examples/t5/main.rs26
1 files changed, 6 insertions, 20 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 55929c33..71106497 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -122,30 +122,16 @@ impl T5ModelBuilder {
}
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
- let weights = self
- .weights_filename
- .iter()
- .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
- .collect::<candle::Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|w| w.deserialize())
- .collect::<candle::Result<Vec<_>>>()?;
- let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
+ let vb = unsafe {
+ VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
+ };
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
- let weights = self
- .weights_filename
- .iter()
- .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
- .collect::<candle::Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|w| w.deserialize())
- .collect::<candle::Result<Vec<_>>>()?;
- let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
+ let vb = unsafe {
+ VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
+ };
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}