path: root/candle-nn/src/encoding.rs
Commit message | Author | Age | Files | Lines
* Fix clippy lints for 1.76. (#1682) | Laurent Mazare | 2024-02-08 | 1 | -1/+1
* Simplify the one-hot implementation, support arbitrary rank. (#1514) | Laurent Mazare | 2024-01-01 | 1 | -181/+38
  * Simplify the one-hot implementation, support arbitrary rank.
  * More cleanup.
* Add one-hot/cold encoding (#1489) | Ryan Tate | 2024-01-01 | 1 | -0/+293
  * add one-hot encoding
  * one_hot: improve error handling, use generic to_vecN::<D>

    Bails if the index value is equal to or greater than the depth value, which would result in an out-of-bounds error. A redundant check is added to ensure the index value does not exceed the size of the one-hot matrix, which would also result in an out-of-bounds error.

    Bails if the index value is less than -1. If the index value is -1, the on_value is not set for that index. Only values less than -1 are considered errors.
  * one-hot: use two generics, one_hot::<I, O>, for input and output data types

    Separating the input and output data types allows the input tensor indices to have a different data type than the encoded output tensor. For example, one_hot::<i64, u8>(...) takes an input tensor of i64 values and encodes the output tensor using u8 values.

    The generic I::DTYPE must match the data type of the input indices, otherwise the method bails.

    Additionally, this commit adds an `allow_f64` option that enables f64 input indices. f64 values are disabled by default.

    TODO: the indices data type and the generic I data type are currently not checked at compile time.
  * one_hot: remove input generic, use indices dtype matching

    This commit removes the to_f64() type cast and explicitly matches on the DType of the input tensor. Currently, only U8, U32 and I64 are supported for input tensors.

    The match arms on the dtype are verbose. It would be nice to pass a generic type with the WithDtype trait bound to the to_vecN method and then return an inner value. Open to suggestions for better approaches to reduce the match-arm verbosity.
  * one_hot: use flat_map iterator over dims instead of nested for loop

    This commit replaces the nested for loops with a flat_map iterator over the dimensions of the input tensor. It also adds a test for a rank-3 input tensor.
  * one_hot: use mandatory on/off-values, remove const msgs

    This commit also updates doc tests, comments and test cases.
  * Small cleanups.

  ---------

  Co-authored-by: laurent <laurent.mazare@gmail.com>
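The index rules in the #1489 message (an index of -1 is skipped, anything below -1 or at/above the depth is an error) and the arbitrary-rank support from #1514 can be illustrated without candle itself. The following is a minimal, std-only Rust sketch under stated assumptions: `one_hot_sketch` is a hypothetical helper, not candle's `one_hot` API; indices are passed as a flat row-major `i64` slice plus an explicit shape; and the result is a flat `f32` buffer whose shape is the input shape with `depth` appended as a new last axis.

```rust
/// Hypothetical, std-only sketch of the one-hot rules described above.
/// This is NOT candle's API; it only mirrors the documented behavior:
/// - index == -1: the row is left entirely at `off_value`
/// - index < -1 or index >= depth: error
fn one_hot_sketch(
    dims: &[usize],
    indices: &[i64],
    depth: usize,
    on_value: f32,
    off_value: f32,
) -> Result<(Vec<usize>, Vec<f32>), String> {
    // The flat buffer must hold exactly one index per element of `dims`.
    let expected: usize = dims.iter().product();
    if expected != indices.len() {
        return Err(format!("shape {dims:?} does not match {} indices", indices.len()));
    }
    // Output shape: input shape with `depth` appended as the last axis.
    let mut out_dims = dims.to_vec();
    out_dims.push(depth);
    let mut out = vec![off_value; indices.len() * depth];
    // Iterating the row-major flattened indices handles any input rank:
    // element i owns the contiguous output run [i * depth, (i + 1) * depth).
    for (i, &idx) in indices.iter().enumerate() {
        if idx == -1 {
            continue; // -1 is ignored: on_value is not written for this row
        }
        if idx < -1 {
            return Err(format!("invalid negative index {idx}"));
        }
        let idx = idx as usize;
        if idx >= depth {
            return Err(format!("index {idx} out of bounds for depth {depth}"));
        }
        out[i * depth + idx] = on_value;
    }
    Ok((out_dims, out))
}

fn main() {
    // A rank-2 input of shape [2, 2]; the -1 entry produces an all-off row.
    let (dims, data) = one_hot_sketch(&[2, 2], &[0, 2, -1, 1], 3, 1.0, 0.0).unwrap();
    println!("{dims:?} {data:?}");
    // Out-of-range and below -1 indices are rejected.
    assert!(one_hot_sketch(&[1], &[3], 3, 1.0, 0.0).is_err());
    assert!(one_hot_sketch(&[1], &[-2], 3, 1.0, 0.0).is_err());
}
```

Working on the flattened buffer is what makes arbitrary rank cheap in this sketch: each input element maps to one contiguous run of `depth` outputs, so no nesting of loops per rank is needed. The real implementation may differ in how it iterates and in which dtypes it accepts.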