summary | refs | log | tree | commit | diff
path: root/candle-nn
diff options
context:
space:
mode:
author: YiiSh <mokeyish@hotmail.com> 2023-12-14 22:08:56 +0800
committer: GitHub <noreply@github.com> 2023-12-14 08:08:56 -0600
commit: e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02 (patch)
tree: 1f4486ec48e37401c55dad0521d00053e02f0a1a /candle-nn
parent: 7be982f6f77fa26fab09792130e1fd707bc728be (diff)
download: candle-e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02.tar.gz
candle-e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02.tar.bz2
candle-e60f9b5dfcab05db8c3c2dbf55fb52f10188aa02.zip
Speedup ShardedSafeTensors to load Tensors with default hints (#1384)
* Speedup ShardedSafeTensors to load Tensors with default hints
* Tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-nn')
-rw-r--r-- candle-nn/src/var_builder.rs 8
1 file changed, 7 insertions, 1 deletion
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index cbd238dd..9d245f12 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors {
fn get(
&self,
- _target_shape: Shape, // The size is not checked for ShardedTensors
+ target_shape: Shape, // The size is only checked when the world size is 1.
path: &str,
h: Self::Hints,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
+ if h.world_size == 1 {
+ // There is no sharding to be applied here so we use the default backend to speed
+ // things up.
+ return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
+ }
+
let Shard {
dim,
rank,