summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs9
1 files changed, 9 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 6af43196..cd68908f 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -877,6 +877,14 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}
+fn randn(device: &Device) -> Result<()> {
+ let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
+ assert_eq!(tensor.dims(), [5, 3]);
+ let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
+ assert_eq!(tensor.dims(), [5, 3]);
+ Ok(())
+}
+
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
@@ -899,6 +907,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(randn, randn_cpu, randn_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381