summary refs log tree commit diff
path: root/candle-nn/src/var_builder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/var_builder.rs')
-rw-r--r-- candle-nn/src/var_builder.rs 10
1 files changed, 8 insertions, 2 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index cbd238dd..83c86a6f 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -40,7 +40,7 @@ struct TensorData<B: Backend> {
/// A trait that defines how tensor data is retrieved.
///
/// Typically this would use disk storage in some specific format, or random initialization.
-/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
+/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most
/// of the time. The main restriction is that it doesn't allow for specific args (besides
/// initialization hints).
pub trait Backend: Send + Sync {
@@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors {
fn get(
&self,
- _target_shape: Shape, // The size is not checked for ShardedTensors
+ target_shape: Shape, // The size is only checked when the world size is 1.
path: &str,
h: Self::Hints,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
+ if h.world_size == 1 {
+ // There is no sharding to be applied here so we use the default backend to speed
+ // things up.
+ return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
+ }
+
let Shard {
dim,
rank,