summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-01 17:23:07 +0100
committerGitHub <noreply@github.com>2023-08-01 17:23:07 +0100
commita27239f3d9b77ad4c300de38d43c6ad64d6b5ea6 (patch)
tree8f31406d35aff7b5c6aecbfbdac773cf31574fce /candle-core/src
parentbabee9f011805f59868b67053bdb8cce0e221e18 (diff)
downloadcandle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.tar.gz
candle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.tar.bz2
candle-a27239f3d9b77ad4c300de38d43c6ad64d6b5ea6.zip
Add training for the llama2.c example (#296)
* Rework the commands and run inference by default. * Add the training module and load the training dataset. * Random dataset iterator. * Proper valid-loss computation. * Compute the evaluation loss. * Add more substance to the training loop.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/error.rs8
-rw-r--r--candle-core/src/lib.rs2
2 files changed, 9 insertions, 1 deletion
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 30d06239..35a33032 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -228,3 +228,11 @@ macro_rules! bail {
return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
};
}
+
+pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
+ match (r1, r2) {
+ (Ok(r1), Ok(r2)) => Ok((r1, r2)),
+ (Err(e), _) => Err(e),
+ (_, Err(e)) => Err(e),
+ }
+}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 95cc189c..52244052 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -44,7 +44,7 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
-mod error;
+pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "mkl")]