summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/tensor.rs24
-rw-r--r--candle-core/tests/tensor_tests.rs35
2 files changed, 59 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index d51a3db7..f6b1698c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2450,6 +2450,30 @@ impl Tensor {
Ok(naxis as usize)
}
}
+
+ /// Returns a lower triangular matrix of ones of size n by n.
+ pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+ let t = Tensor::arange(0u32, n as u32, device)?;
+ let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+ let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+ t1.le(&t2)?.to_dtype(dtype)
+ }
+
+ /// Returns an upper triangular matrix of ones of size n by n.
+ pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+ let t = Tensor::arange(0u32, n as u32, device)?;
+ let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+ let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+ t1.ge(&t2)?.to_dtype(dtype)
+ }
+
+ /// Returns a matrix with a diagonal of ones of size n by n.
+ pub fn eye(n: usize, dtype: DType, device: &Device) -> Result<Self> {
+ let t = Tensor::arange(0u32, n as u32, device)?;
+ let t1 = t.reshape((1, n))?.broadcast_as((n, n))?;
+ let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?;
+ t1.eq(&t2)?.to_dtype(dtype)
+ }
}
macro_rules! bin_trait {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index cc44ce94..c8b255dd 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1134,3 +1134,38 @@ fn i64_abs() -> Result<()> {
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}
+
+#[test]
+fn tril_triu_eye() -> Result<()> {
+ let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?;
+ assert_eq!(
+ t.to_vec2::<f32>()?,
+ [
+ [1.0, 0.0, 0.0, 0.0],
+ [1.0, 1.0, 0.0, 0.0],
+ [1.0, 1.0, 1.0, 0.0],
+ [1.0, 1.0, 1.0, 1.0]
+ ],
+ );
+ let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?;
+ assert_eq!(
+ t.to_vec2::<f32>()?,
+ [
+ [1.0, 1.0, 1.0, 1.0],
+ [0.0, 1.0, 1.0, 1.0],
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 1.0]
+ ]
+ );
+ let t = Tensor::eye(4, DType::F32, &Device::Cpu)?;
+ assert_eq!(
+ t.to_vec2::<f32>()?,
+ [
+ [1.0, 0.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0, 0.0],
+ [0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 0.0, 1.0]
+ ]
+ );
+ Ok(())
+}