diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-01 17:23:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-01 17:23:07 +0100 |
commit | a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6 (patch) | |
tree | 8f31406d35aff7b5c6aecbfbdac773cf31574fce /candle-core/src | |
parent | babee9f011805f59868b67053bdb8cce0e221e18 (diff) | |
download | candle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.tar.gz candle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.tar.bz2 candle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.zip |
Add training for the llama2.c example (#296)
* Rework the commands and run inference by default.
* Add the training module and load the training dataset.
* Random dataset iterator.
* Proper valid-loss computation.
* Compute the evaluation loss.
* Add more substance to the training loop.
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/error.rs | 8 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 2 |
2 files changed, 9 insertions, 1 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 30d06239..35a33032 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -228,3 +228,11 @@ macro_rules! bail { return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt()) }; } + +pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> { + match (r1, r2) { + (Ok(r1), Ok(r2)) => Ok((r1, r2)), + (Err(e), _) => Err(e), + (_, Err(e)) => Err(e), + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 95cc189c..52244052 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -44,7 +44,7 @@ mod device; pub mod display; mod dtype; mod dummy_cuda_backend; -mod error; +pub mod error; mod indexer; pub mod layout; #[cfg(feature = "mkl")] |