diff options
-rw-r--r-- | candle-core/src/tensor.rs | 24 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 35 |
2 files changed, 59 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index d51a3db7..f6b1698c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2450,6 +2450,30 @@ impl Tensor { Ok(naxis as usize) } } + + /// Returns a lower triangular matrix of ones of size n by n. + pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.le(&t2)?.to_dtype(dtype) + } + + /// Returns an upper triangular matrix of ones of size n by n. + pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.ge(&t2)?.to_dtype(dtype) + } + + /// Returns a matrix with a diagonal of ones of size n by n. + pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.eq(&t2)?.to_dtype(dtype) + } } macro_rules! bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index cc44ce94..c8b255dd 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1134,3 +1134,38 @@ fn i64_abs() -> Result<()> { assert_eq!(t.to_vec1::<i64>()?, [42, 1337]); Ok(()) } + +#[test] +fn tril_triu_eye() -> Result<()> { + let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::<f32>()?, + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0] + ], + ); + let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::<f32>()?, + [ + [1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0] + ] + ); + let t = Tensor::eye(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::<f32>()?, + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ] + ); + Ok(()) +} |