diff options
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r-- | candle-examples/examples/t5/main.rs | 26 |
1 files changed, 6 insertions, 20 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 55929c33..71106497 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -122,30 +122,16 @@ impl T5ModelBuilder { } pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> { - let weights = self - .weights_filename - .iter() - .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) - .collect::<candle::Result<Vec<_>>>()?; - let weights = weights - .iter() - .map(|w| w.deserialize()) - .collect::<candle::Result<Vec<_>>>()?; - let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device); + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)? + }; Ok(t5::T5EncoderModel::load(vb, &self.config)?) } pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> { - let weights = self - .weights_filename - .iter() - .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) - .collect::<candle::Result<Vec<_>>>()?; - let weights = weights - .iter() - .map(|w| w.deserialize()) - .collect::<candle::Result<Vec<_>>>()?; - let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device); + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)? + }; Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) } } |