summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor_cat.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor_cat.rs')
-rw-r--r--candle-core/src/tensor_cat.rs22
1 files changed, 10 insertions, 12 deletions
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
index 31cc8503..27ff7851 100644
--- a/candle-core/src/tensor_cat.rs
+++ b/candle-core/src/tensor_cat.rs
@@ -58,20 +58,18 @@ impl Tensor {
}
}
}
- if dim == 0 {
+ let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
+ if all_contiguous {
+ Self::cat_contiguous(args, dim)
+ } else if dim == 0 {
Self::cat0(args)
} else {
- let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
- if all_contiguous {
- Self::cat_contiguous(args, dim)
- } else {
- let args: Vec<Tensor> = args
- .iter()
- .map(|a| a.as_ref().transpose(0, dim))
- .collect::<Result<Vec<_>>>()?;
- let cat = Self::cat0(&args)?;
- cat.transpose(0, dim)
- }
+ let args: Vec<Tensor> = args
+ .iter()
+ .map(|a| a.as_ref().transpose(0, dim))
+ .collect::<Result<Vec<_>>>()?;
+ let cat = Self::cat0(&args)?;
+ cat.transpose(0, dim)
}
}