summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-31 14:14:01 +0100
committerGitHub <noreply@github.com>2023-07-31 14:14:01 +0100
commitffeafbfc43307fe4e2daa3e3fdfe7afb781c5505 (patch)
tree110490e3906c4c7e25e84286f2a190c4d0df5638 /candle-nn/src
parentb3ea96b62bed2e347d63489f16172c11b8093950 (diff)
downloadcandle-ffeafbfc43307fe4e2daa3e3fdfe7afb781c5505.tar.gz
candle-ffeafbfc43307fe4e2daa3e3fdfe7afb781c5505.tar.bz2
candle-ffeafbfc43307fe4e2daa3e3fdfe7afb781c5505.zip
Make the nll op closer to the pytorch version + add a test. (#286)
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/loss.rs24
1 files changed, 22 insertions, 2 deletions
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index d388af6c..b380c426 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -1,8 +1,28 @@
use candle::{Result, Tensor};
+/// The negative log likelihood loss.
+///
+/// Arguments
+///
+/// * `inp`: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+/// of categories. This is expected to contain log probabilities.
+/// * `target`: The ground truth labels as a tensor of u32 of dimension `N`.
+///
+/// Returns a scalar with the average loss over the batch; errors on a rank or batch size mismatch.
pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
- let b_sz = target.dim(0)?;
- inp.gather(target, 1)?
+ let b_sz = match target.dims() {
+ &[b_sz] => b_sz,
+ dims => candle::bail!("the target tensor should have a single dimension ({dims:?})"),
+ };
+ match inp.dims() {
+ &[inp_b_sz, _] => {
+ if inp_b_sz != b_sz {
+ candle::bail!("batch size mismatch between inp ({inp_b_sz}) and target ({b_sz})")
+ }
+ }
+ dims => candle::bail!("the inp tensor should have two dimensions ({dims:?})"),
+ }
+ inp.gather(&target.unsqueeze(1)?, 1)?
.sum_all()?
.affine(-1f64 / b_sz as f64, 0.)
}