diff options
Diffstat (limited to 'candle-core/src/tensor_cat.rs')
-rw-r--r-- | candle-core/src/tensor_cat.rs | 22 |
1 files changed, 10 insertions, 12 deletions
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs index 31cc8503..27ff7851 100644 --- a/candle-core/src/tensor_cat.rs +++ b/candle-core/src/tensor_cat.rs @@ -58,20 +58,18 @@ impl Tensor { } } } - if dim == 0 { + let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous()); + if all_contiguous { + Self::cat_contiguous(args, dim) + } else if dim == 0 { Self::cat0(args) } else { - let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous()); - if all_contiguous { - Self::cat_contiguous(args, dim) - } else { - let args: Vec<Tensor> = args - .iter() - .map(|a| a.as_ref().transpose(0, dim)) - .collect::<Result<Vec<_>>>()?; - let cat = Self::cat0(&args)?; - cat.transpose(0, dim) - } + let args: Vec<Tensor> = args + .iter() + .map(|a| a.as_ref().transpose(0, dim)) + .collect::<Result<Vec<_>>>()?; + let cat = Self::cat0(&args)?; + cat.transpose(0, dim) } } |