summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/dtype.rs11
-rw-r--r--candle-core/src/shape.rs6
-rw-r--r--candle-core/src/tensor.rs1
3 files changed, 18 insertions, 0 deletions
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index adfc4a3c..c7a1567f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -1,15 +1,24 @@
+//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
+/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
+ // Unsigned 8 bits integer.
U8,
+ // Unsigned 32 bits integer.
U32,
+ // Signed 64 bits integer.
I64,
+ // Brain floating-point using half precision (16 bits).
BF16,
+ // Floating-point using half precision (16 bits).
F16,
+ // Floating-point using single precision (32 bits).
F32,
+ // Floating-point using double precision (64 bits).
F64,
}
@@ -33,6 +42,7 @@ impl std::str::FromStr for DType {
}
impl DType {
+ /// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
@@ -45,6 +55,7 @@ impl DType {
}
}
+ /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 1,
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index aea8b887..db0fe98a 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -1,3 +1,4 @@
+//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};
@@ -119,6 +120,7 @@ impl Shape {
Self(dims.to_vec())
}
+ /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
pub fn rank(&self) -> usize {
self.0.len()
}
@@ -127,10 +129,12 @@ impl Shape {
self.0
}
+ /// The dimensions as a slice of `usize`.
pub fn dims(&self) -> &[usize] {
&self.0
}
+ /// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
@@ -182,6 +186,8 @@ impl Shape {
true
}
+ /// Modifies the shape by adding a list of additional dimensions at the end of the existing
+ /// dimensions.
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e181f240..1eca694c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+//! Tensors are N-dimenional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{