summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-30 13:22:00 +0100
committerGitHub <noreply@github.com>2024-03-30 13:22:00 +0100
commitb190fd85920dfeb93c091593d42fda596c3a83a7 (patch)
tree52003bc7540277e37f006198e77a145fbf8dccd4 /candle-core/src
parentefe4a0c84b55b60f7555a89ea7e0ba8d300104cd (diff)
downloadcandle-b190fd85920dfeb93c091593d42fda596c3a83a7.tar.gz
candle-b190fd85920dfeb93c091593d42fda596c3a83a7.tar.bz2
candle-b190fd85920dfeb93c091593d42fda596c3a83a7.zip
Remove some unnecessary calls to contiguous. (#1968)
* Remove some unnecessary calls to contiguous. * Slightly improved kv cache concatenation.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/tensor_cat.rs22
1 files changed, 10 insertions, 12 deletions
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
index 31cc8503..27ff7851 100644
--- a/candle-core/src/tensor_cat.rs
+++ b/candle-core/src/tensor_cat.rs
@@ -58,20 +58,18 @@ impl Tensor {
}
}
}
- if dim == 0 {
+ let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
+ if all_contiguous {
+ Self::cat_contiguous(args, dim)
+ } else if dim == 0 {
Self::cat0(args)
} else {
- let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
- if all_contiguous {
- Self::cat_contiguous(args, dim)
- } else {
- let args: Vec<Tensor> = args
- .iter()
- .map(|a| a.as_ref().transpose(0, dim))
- .collect::<Result<Vec<_>>>()?;
- let cat = Self::cat0(&args)?;
- cat.transpose(0, dim)
- }
+ let args: Vec<Tensor> = args
+ .iter()
+ .map(|a| a.as_ref().transpose(0, dim))
+ .collect::<Result<Vec<_>>>()?;
+ let cat = Self::cat0(&args)?;
+ cat.transpose(0, dim)
}
}