From e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02 Mon Sep 17 00:00:00 2001 From: YiiSh Date: Thu, 14 Dec 2023 22:08:56 +0800 Subject: Speedup ShardedSafeTensors to load Tensors with default hints (#1384) * Speedup ShardedSafeTensors to load Tensors with default hints * Tweaks. --------- Co-authored-by: Laurent --- candle-nn/src/var_builder.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'candle-nn') diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index cbd238dd..9d245f12 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors { fn get( &self, - _target_shape: Shape, // The size is not checked for ShardedTensors + target_shape: Shape, // The size is only checked when the world size is 1. path: &str, h: Self::Hints, dtype: DType, dev: &Device, ) -> Result<Tensor> { + if h.world_size == 1 { + // There is no sharding to be applied here so we use the default backend to speed + // things up. + return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev); + } + let Shard { dim, rank, -- cgit v1.2.3