summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs26
1 files changed, 26 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index dbe0dd6a..d3eede48 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -8,6 +8,31 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}
+fn ones(device: &Device) -> Result<()> {
+ assert_eq!(
+ Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
+ [[1, 1, 1], [1, 1, 1]],
+ );
+ assert_eq!(
+ Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
+ [[1, 1, 1], [1, 1, 1]],
+ );
+ assert_eq!(
+ Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
+ [[1, 1, 1], [1, 1, 1]],
+ );
+ assert_eq!(
+ Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
+ [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+ );
+ assert_eq!(
+ Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
+ [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+ );
+
+ Ok(())
+}
+
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
@@ -966,6 +991,7 @@ fn randn(device: &Device) -> Result<()> {
}
test_device!(zeros, zeros_cpu, zeros_gpu);
+test_device!(ones, ones_cpu, ones_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);