diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-31 14:14:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-31 14:14:01 +0100 |
commit | ffeafbfc43307fe4e2daa3e3fdfe7afb781c5505 (patch) | |
tree | 110490e3906c4c7e25e84286f2a190c4d0df5638 /candle-nn/src | |
parent | b3ea96b62bed2e347d63489f16172c11b8093950 (diff) | |
download | candle-ffeafbfc43307fe4e2daa3e3fdfe7afb781c5505.tar.gz candle-ffeafbfc43307fe4e2daa3e3fdfe7afb781c5505.tar.bz2 candle-ffeafbfc43307fe4e2daa3e3fdfe7afb781c5505.zip |
Make the nll op closer to the pytorch version + add a test. (#286)
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/loss.rs | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs index d388af6c..b380c426 100644 --- a/candle-nn/src/loss.rs +++ b/candle-nn/src/loss.rs @@ -1,8 +1,28 @@ use candle::{Result, Tensor}; +/// The negative log likelihood loss. +/// +/// Arguments +/// +/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number +/// of categories. This is expected to contain log probabilities. +/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`. +/// +/// The resulting tensor is a scalar containing the average value over the batch. pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> { - let b_sz = target.dim(0)?; - inp.gather(target, 1)? + let b_sz = match target.dims() { + &[b_sz] => b_sz, + dims => candle::bail!("the target tensor should have a single dimension ({dims:?})"), + }; + match inp.dims() { + &[inp_b_sz, _] => { + if inp_b_sz != b_sz { + candle::bail!("batch size mismatch between inp ({inp_b_sz}) and target ({b_sz})") + } + } + dims => candle::bail!("the target tensor should have two dimensions ({dims:?})"), + } + inp.gather(&target.unsqueeze(1)?, 1)? .sum_all()? .affine(-1f64 / b_sz as f64, 0.) } |