author     Ogundepo Odunayo <ogundepoodunayo@gmail.com>    2023-10-23 12:12:44 -0400
committer  GitHub <noreply@github.com>                     2023-10-23 17:12:44 +0100
commit     86e1803191e2ed44c57ad01807b29a886c0263bb (patch)
tree       6ccabebecb1ebcb9402fa954f3d46dceb9e354bb /candle-nn
parent     25c3cc4149304a4f6eec93b2f88aa9c241f8f696 (diff)
Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
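For reference, the loss added by this diff is the standard binary cross-entropy applied to sigmoid-squashed logits. Written out as a sketch (the symbols x for the logits, y ∈ {0, 1} for the targets, and σ for the sigmoid are ours, not part of the commit), the returned scalar is the mean over all N × C entries:

\mathcal{L}(x, y) = -\frac{1}{NC} \sum_{n=1}^{N} \sum_{c=1}^{C} \Big[ y_{nc} \log \sigma(x_{nc}) + (1 - y_{nc}) \log\big(1 - \sigma(x_{nc})\big) \Big]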
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/loss.rs   | 22 ++++++++++++++++++++++
-rw-r--r--  candle-nn/tests/loss.rs | 47 +++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 69 insertions(+), 0 deletions(-)
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index 72451f83..fb1e11f4 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
 pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
     (inp - target)?.sqr()?.mean_all()
 }
+
+/// The binary cross-entropy with logit loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+///   of categories. This is expected to contain raw logits.
+/// * [target]: The ground truth labels as a tensor of `0.`/`1.` values of dimension `N, C` where
+///   `N` is the batch size and `C` the number of categories.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
+pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    let inp = crate::ops::sigmoid(inp)?;
+
+    let left_side = target * inp.log()?;
+    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
+
+    let loss = left_side? + right_side?;
+    let loss = loss?.neg()?.mean_all()?;
+
+    Ok(loss)
+}
diff --git a/candle-nn/tests/loss.rs b/candle-nn/tests/loss.rs
index d772f176..ccfc029f 100644
--- a/candle-nn/tests/loss.rs
+++ b/candle-nn/tests/loss.rs
@@ -39,3 +39,50 @@ fn nll_and_cross_entropy() -> Result<()> {
     assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
     Ok(())
 }
+
+/* Equivalent python code:
+import torch
+import torch.nn.functional as F
+
+inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
+                    [ 0.0419,  0.0763, -1.0457, -1.6692],
+                    [-1.0494,  0.8111,  1.5723,  1.2315],
+                    [ 1.3081,  0.6641,  1.1802, -0.2547],
+                    [ 0.5292,  0.7636,  0.3692, -0.8318]])
+
+target = torch.Tensor([[0., 1., 0., 0.],
+                       [0., 1., 0., 0.],
+                       [0., 0., 0., 1.],
+                       [1., 0., 0., 0.],
+                       [0., 0., 1., 0.]])
+
+print(F.binary_cross_entropy_with_logits(inp, target))
+*/
+#[test]
+fn binary_cross_entropy_with_logit() -> Result<()> {
+    let cpu = Device::Cpu;
+
+    let inp = [
+        [2.3611f32, -0.8813, -0.5006, -0.2178],
+        [0.0419, 0.0763, -1.0457, -1.6692],
+        [-1.0494, 0.8111, 1.5723, 1.2315],
+        [1.3081, 0.6641, 1.1802, -0.2547],
+        [0.5292, 0.7636, 0.3692, -0.8318],
+    ];
+
+    let target = [
+        [0.0f32, 1., 0., 0.],
+        [0., 1., 0., 0.],
+        [0., 0., 0., 1.],
+        [1., 0., 0., 0.],
+        [0., 0., 1., 0.],
+    ];
+
+    let inp = Tensor::new(&inp, &cpu)?;
+    let target = Tensor::new(&target, &cpu)?;
+
+    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
+
+    assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
+    Ok(())
+}
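A minimal usage sketch of the new function, assuming the published crate names candle_core and candle_nn (inside the candle workspace the core crate is imported as candle instead); the sample logits and labels are made up for illustration:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw (pre-sigmoid) logits for a batch of N = 2 samples and C = 3 categories.
    let logits = Tensor::new(&[[0.5f32, -1.0, 2.0], [1.5, 0.0, -0.5]], &device)?;
    // Ground-truth labels as 0./1. floats with the same N, C shape.
    let target = Tensor::new(&[[1.0f32, 0., 0.], [0., 1., 0.]], &device)?;
    // Returns a scalar tensor: the loss averaged over all N * C entries.
    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&logits, &target)?;
    println!("{loss}");
    Ok(())
}

One design note on the diff itself: the implementation applies sigmoid first and log afterwards, which is concise but can lose precision for large-magnitude logits, whereas PyTorch's binary_cross_entropy_with_logits fuses the two via the log-sum-exp trick precisely for numerical stability.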