summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs7
1 files changed, 5 insertions, 2 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index a126d634..95ce982a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -270,7 +270,11 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
- // TODO: This is not the expected answer, to be fixed!
+ // PyTorch equivalent:
+ // import torch
+ // t1 = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]])
+ // t2 = torch.tensor([[5]*5, [2, 7, 1, 8, 2]])
+ // torch.cat([t1.t(), t2.t()], dim=1).t()
assert_eq!(
Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
.t()?
@@ -282,7 +286,6 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
- // TODO: This is not the expected answer, to be fixed!
assert_eq!(
Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
[