diff options
author | YiiSh <mokeyish@hotmail.com> | 2023-12-14 22:08:56 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-14 08:08:56 -0600 |
commit | e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02 (patch) | |
tree | 1f4486ec48e37401c55dad0521d00053e02f0a1a /candle-nn | |
parent | 7be982f6f77fa26fab09792130e1fd707bc728be (diff) | |
download | candle-e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02.tar.gz candle-e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02.tar.bz2 candle-e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02.zip |
Speedup ShardedSafeTensors to load Tensors with default hints (#1384)
* Speedup ShardedSafeTensors to load Tensors with default hints
* Tweaks.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/var_builder.rs | 8 |
1 files changed, 7 insertions, 1 deletions
```diff
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index cbd238dd..9d245f12 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors {
     fn get(
         &self,
-        _target_shape: Shape, // The size is not checked for ShardedTensors
+        target_shape: Shape, // The size is only checked when the world size is 1.
         path: &str,
         h: Self::Hints,
         dtype: DType,
         dev: &Device,
     ) -> Result<Tensor> {
+        if h.world_size == 1 {
+            // There is no sharding to be applied here so we use the default backend to speed
+            // things up.
+            return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
+        }
+
         let Shard {
             dim,
             rank,
```