diff options
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index a126d634..95ce982a 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -270,7 +270,11 @@ fn cat(device: &Device) -> Result<()> { [2.0, 7.0, 1.0, 8.0, 2.0] ] ); - // TODO: This is not the expected answer, to be fixed! + // PyTorch equivalent: + // import torch + // t1 = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]]) + // t2 = torch.tensor([[5]*5, [2, 7, 1, 8, 2]]) + // torch.cat([t1.t(), t2.t()], dim=1).t() assert_eq!( Tensor::cat(&[&t1.t()?, &t2.t()?], 1)? .t()? @@ -282,7 +286,6 @@ fn cat(device: &Device) -> Result<()> { [2.0, 7.0, 1.0, 8.0, 2.0] ] ); - // TODO: This is not the expected answer, to be fixed! assert_eq!( Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?, [ |