summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/tensor.rs10
-rw-r--r--candle-core/tests/tensor_tests.rs9
2 files changed, 19 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e1cae41c..e6e7b415 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -361,6 +361,16 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}
+ /// Returns a new tensor with all the elements having the same specified value. Note that
+ /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
+ pub fn full<D: crate::WithDType, S: Into<Shape>>(
+ value: D,
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
+ }
+
/// Creates a new 1D tensor from an iterator.
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index a4548d56..e83fb55b 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -32,6 +32,14 @@ fn ones(device: &Device) -> Result<()> {
Ok(())
}
+fn full(device: &Device) -> Result<()> {
+ assert_eq!(
+ Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
+ [[42, 42, 42], [42, 42, 42]],
+ );
+ Ok(())
+}
+
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
@@ -1072,6 +1080,7 @@ fn randn(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
+test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);