summaryrefslogtreecommitdiff
path: root/candle-examples/examples/t5/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r--candle-examples/examples/t5/main.rs26
1 files changed, 6 insertions, 20 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 55929c33..71106497 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -122,30 +122,16 @@ impl T5ModelBuilder {
}
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
- let weights = self
- .weights_filename
- .iter()
- .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
- .collect::<candle::Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|w| w.deserialize())
- .collect::<candle::Result<Vec<_>>>()?;
- let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
+ let vb = unsafe {
+ VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
+ };
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
- let weights = self
- .weights_filename
- .iter()
- .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
- .collect::<candle::Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|w| w.deserialize())
- .collect::<candle::Result<Vec<_>>>()?;
- let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
+ let vb = unsafe {
+ VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
+ };
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}