summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/bert/main.rs5
-rw-r--r--candle-examples/examples/bigcode/main.rs11
-rw-r--r--candle-examples/examples/dinov2/main.rs4
-rw-r--r--candle-examples/examples/efficientnet/main.rs4
-rw-r--r--candle-examples/examples/falcon/main.rs11
-rw-r--r--candle-examples/examples/llama/main.rs10
-rw-r--r--candle-examples/examples/musicgen/main.rs4
-rw-r--r--candle-examples/examples/phi/main.rs11
-rw-r--r--candle-examples/examples/segment-anything/main.rs4
-rw-r--r--candle-examples/examples/whisper/main.rs5
-rw-r--r--candle-examples/examples/wuerstchen/main.rs24
-rw-r--r--candle-examples/examples/yolo-v3/main.rs4
-rw-r--r--candle-examples/examples/yolo-v8/main.rs4
-rw-r--r--candle-transformers/src/models/stable_diffusion/mod.rs14
14 files changed, 31 insertions, 84 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 9d0eccdf..70592013 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -86,9 +86,8 @@ impl Args {
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
- let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+ let vb =
+ unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 5f17109e..bf8dd24c 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -138,18 +138,9 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
- let weights = filenames
- .iter()
- .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
- .collect::<Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|f| Ok(f.deserialize()?))
- .collect::<Result<Vec<_>>>()?;
-
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
- let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::starcoder_1b();
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs
index d3adb37c..6b3edeb4 100644
--- a/candle-examples/examples/dinov2/main.rs
+++ b/candle-examples/examples/dinov2/main.rs
@@ -42,9 +42,7 @@ pub fn main() -> anyhow::Result<()> {
}
Some(model) => model.into(),
};
- let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = dinov2::vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
diff --git a/candle-examples/examples/efficientnet/main.rs b/candle-examples/examples/efficientnet/main.rs
index 1e45e301..0e4a2864 100644
--- a/candle-examples/examples/efficientnet/main.rs
+++ b/candle-examples/examples/efficientnet/main.rs
@@ -68,9 +68,7 @@ pub fn main() -> anyhow::Result<()> {
}
Some(model) => model.into(),
};
- let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let cfg = match args.which {
Which::B0 => MBConvConfig::b0(),
Which::B1 => MBConvConfig::b1(),
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index b0973d64..1cef25a8 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -177,21 +177,12 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
- let weights = filenames
- .iter()
- .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
- .collect::<Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|f| Ok(f.deserialize()?))
- .collect::<Result<Vec<_>>>()?;
-
let dtype = if args.use_f32 {
DType::F32
} else {
DType::BF16
};
- let vb = VarBuilder::from_safetensors(weights, dtype, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let config = Config::falcon7b();
config.validate()?;
let model = Falcon::load(vb, config)?;
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index b2d7d938..4bf91d92 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -172,17 +172,9 @@ fn main() -> Result<()> {
}
println!("building the model");
- let handles = filenames
- .iter()
- .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
- .collect::<Result<Vec<_>>>()?;
- let tensors: Vec<_> = handles
- .iter()
- .map(|h| Ok(h.deserialize()?))
- .collect::<Result<Vec<_>>>()?;
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
- let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
}
};
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 0fae67b5..a39cfec2 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -73,9 +73,7 @@ fn main() -> Result<()> {
))
.get("model.safetensors")?,
};
- let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
- let model = model.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
let config = GenConfig::small();
let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 25c7db98..3b1e7dc1 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -149,18 +149,9 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
- let weights = filenames
- .iter()
- .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
- .collect::<Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|f| Ok(f.deserialize()?))
- .collect::<Result<Vec<_>>>()?;
-
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
- let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::v1_5();
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 3d9898b6..71abe116 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -82,9 +82,7 @@ pub fn main() -> anyhow::Result<()> {
api.get(filename)?
}
};
- let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let sam = if args.use_tiny {
sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c71d562a..0aa4db41 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -481,9 +481,8 @@ fn main() -> Result<()> {
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
- let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
+ let vb =
+ unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = Whisper::load(&vb, config)?;
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index 95f3b8f4..40b43c1d 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -287,10 +287,10 @@ fn run(args: Args) -> Result<()> {
)?;
let prior = {
- let prior_weights = ModelFile::Prior.get(prior_weights)?;
- let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
- let weights = weights.deserialize()?;
- let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let file = ModelFile::Prior.get(prior_weights)?;
+ let vb = unsafe {
+ candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
+ };
wuerstchen::prior::WPrior::new(
/* c_in */ PRIOR_CIN,
/* c */ 1536,
@@ -324,10 +324,10 @@ fn run(args: Args) -> Result<()> {
println!("Building the vqgan.");
let vqgan = {
- let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
- let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
- let weights = weights.deserialize()?;
- let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let file = ModelFile::VqGan.get(vqgan_weights)?;
+ let vb = unsafe {
+ candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
+ };
wuerstchen::paella_vq::PaellaVQ::new(vb)?
};
@@ -335,10 +335,10 @@ fn run(args: Args) -> Result<()> {
// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
let decoder = {
- let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
- let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
- let weights = weights.deserialize()?;
- let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let file = ModelFile::Decoder.get(decoder_weights)?;
+ let vb = unsafe {
+ candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
+ };
wuerstchen::diffnext::WDiffNeXt::new(
/* c_in */ DECODER_CIN,
/* c_out */ DECODER_CIN,
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index ecf75bdf..5b1937ac 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -146,9 +146,7 @@ pub fn main() -> Result<()> {
// Create the model and load the weights from the file.
let model = args.model()?;
- let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &Device::Cpu)? };
let config = args.config()?;
let darknet = darknet::parse_config(config)?;
let model = darknet.build_model(vb)?;
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index dc709db4..af8cf98a 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -381,9 +381,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
Which::X => Multiples::x(),
};
let model = args.model()?;
- let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
- let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = T::load(vb, multiples)?;
println!("model loaded");
for image_name in args.images.iter() {
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index c6f1b904..7fdedaae 100644
--- a/candle-transformers/src/models/stable_diffusion/mod.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -255,9 +255,8 @@ impl StableDiffusionConfig {
device: &Device,
dtype: DType,
) -> Result<vae::AutoEncoderKL> {
- let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
- let weights = weights.deserialize()?;
- let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let vs_ae =
+ unsafe { nn::VarBuilder::from_mmaped_safetensors(&[vae_weights], dtype, device)? };
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
Ok(autoencoder)
@@ -271,9 +270,8 @@ impl StableDiffusionConfig {
use_flash_attn: bool,
dtype: DType,
) -> Result<unet_2d::UNet2DConditionModel> {
- let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
- let weights = weights.deserialize()?;
- let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let vs_unet =
+ unsafe { nn::VarBuilder::from_mmaped_safetensors(&[unet_weights], dtype, device)? };
let unet = unet_2d::UNet2DConditionModel::new(
vs_unet,
in_channels,
@@ -295,9 +293,7 @@ pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
device: &Device,
dtype: DType,
) -> Result<clip::ClipTextTransformer> {
- let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
- let weights = weights.deserialize()?;
- let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let vs = unsafe { nn::VarBuilder::from_mmaped_safetensors(&[clip_weights], dtype, device)? };
let text_model = clip::ClipTextTransformer::new(vs, clip)?;
Ok(text_model)
}