diff options
Diffstat (limited to 'candle-nn/src/var_builder.rs')
| -rw-r--r-- | candle-nn/src/var_builder.rs | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index cbd238dd..83c86a6f 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -40,7 +40,7 @@ struct TensorData<B: Backend> {
 /// A trait that defines how tensor data is retrieved.
 ///
 /// Typically this would use disk storage in some specific format, or random initialization.
-/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
+/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most
 /// of the time. The main restriction is that it doesn't allow for specific args (besides
 /// initialization hints).
 pub trait Backend: Send + Sync {
@@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors {
 
     fn get(
         &self,
-        _target_shape: Shape, // The size is not checked for ShardedTensors
+        target_shape: Shape, // The size is only checked when the world size is 1.
         path: &str,
         h: Self::Hints,
         dtype: DType,
         dev: &Device,
     ) -> Result<Tensor> {
+        if h.world_size == 1 {
+            // There is no sharding to be applied here so we use the default backend to speed
+            // things up.
+            return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
+        }
+
         let Shard {
             dim,
             rank,
