diff options
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 6af43196..cd68908f 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -877,6 +877,14 @@ fn broadcasting(device: &Device) -> Result<()> { Ok(()) } +fn randn(device: &Device) -> Result<()> { + let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?; + assert_eq!(tensor.dims(), [5, 3]); + let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?; + assert_eq!(tensor.dims(), [5, 3]); + Ok(()) +} + test_device!(zeros, zeros_cpu, zeros_gpu); test_device!(add_mul, add_mul_cpu, add_mul_gpu); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu); @@ -899,6 +907,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu); test_device!(index_add, index_add_cpu, index_add_gpu); test_device!(gather, gather_cpu, gather_gpu); test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu); +test_device!(randn, randn_cpu, randn_gpu); // There was originally a bug on the CPU implementation for randn // https://github.com/huggingface/candle/issues/381 |