summaryrefslogtreecommitdiff
path: root/candle-pyo3/src
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-01 20:55:15 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-01 20:55:15 +0100
commit42d1a52d01f5f10e3a04257cb4612225f08e1321 (patch)
tree9e9ead92132df169c61f137cb39d77a9d7d62969 /candle-pyo3/src
parent52db2a6849659f0029a6fd136c47e945d8eef50f (diff)
downloadcandle-42d1a52d01f5f10e3a04257cb4612225f08e1321.tar.gz
candle-42d1a52d01f5f10e3a04257cb4612225f08e1321.tar.bz2
candle-42d1a52d01f5f10e3a04257cb4612225f08e1321.zip
Add two methods.
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r--candle-pyo3/src/lib.rs14
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index be9e427d..e1ce7f97 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -25,9 +25,23 @@ impl PyTensor {
Ok(Self(Tensor::new(f, &Cpu).map_err(wrap_err)?))
}
+ #[getter]
+ fn shape(&self) -> Vec<usize> {
+ self.0.dims().to_vec()
+ }
+
+ #[getter]
+ fn rank(&self) -> usize {
+ self.0.rank()
+ }
+
fn __repr__(&self) -> String {
format!("{}", self.0)
}
+
+ fn __str__(&self) -> String {
+ self.__repr__()
+ }
}
#[pyfunction]