summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-nn/src/var_builder.rs8
1 files changed, 7 insertions, 1 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index cbd238dd..9d245f12 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors {
fn get(
&self,
- _target_shape: Shape, // The size is not checked for ShardedTensors
+ target_shape: Shape, // The size is only checked when the world size is 1.
path: &str,
h: Self::Hints,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
+ if h.world_size == 1 {
+ // There is no sharding to be applied here so we use the default backend to speed
+ // things up.
+ return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
+ }
+
let Shard {
dim,
rank,