author     Ogundepo Odunayo <ogundepoodunayo@gmail.com>  2023-10-23 12:12:44 -0400
committer  GitHub <noreply@github.com>                   2023-10-23 17:12:44 +0100
commit     86e1803191e2ed44c57ad01807b29a886c0263bb (patch)
tree       6ccabebecb1ebcb9402fa954f3d46dceb9e354bb /candle-nn
parent     25c3cc4149304a4f6eec93b2f88aa9c241f8f696 (diff)
Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/loss.rs    22
-rw-r--r--  candle-nn/tests/loss.rs  47
2 files changed, 69 insertions, 0 deletions
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index 72451f83..fb1e11f4 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    (inp - target)?.sqr()?.mean_all()
}
+
+/// The binary cross-entropy with logit loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+/// of categories. This is expected to contain raw logits.
+/// * [target]: The ground truth labels as a float tensor of dimensions `N, C` where `N` is the
+/// batch size and `C` the number of categories, with values in `{0., 1.}`.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
+pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    // Turn the raw logits into probabilities.
+    let inp = crate::ops::sigmoid(inp)?;
+
+    // target * log(p) and (1 - target) * log(1 - p); affine(-1., 1.) computes `1 - x`.
+    let left_side = (target * inp.log()?)?;
+    let right_side = (target.affine(-1., 1.)? * inp.affine(-1., 1.)?.log()?)?;
+
+    // loss = -mean(target * log(p) + (1 - target) * log(1 - p))
+    (left_side + right_side)?.neg()?.mean_all()
+}
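
Note that the implementation above computes log(sigmoid(x)) directly, which can lose precision for large-magnitude logits. For comparison only (this is not part of the commit), a sketch of the usual numerically stable reformulation, mean(max(x, 0) - x * t + log(1 + exp(-|x|))), the same form PyTorch uses for F.binary_cross_entropy_with_logits; the helper name is hypothetical and the tensor ops used (relu, abs, neg, exp, affine, log) are assumed to be available:

use candle::{Result, Tensor};

/// Hypothetical sketch of a sigmoid-free binary cross-entropy with logits.
fn bce_with_logit_stable(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    let pos = inp.relu()?; // max(x, 0)
    let xt = (inp * target)?; // x * t
    // log(1 + exp(-|x|)); affine(1., 1.) computes `x + 1`.
    let log_term = inp.abs()?.neg()?.exp()?.affine(1., 1.)?.log()?;
    ((pos - xt)? + log_term)?.mean_all()
}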
diff --git a/candle-nn/tests/loss.rs b/candle-nn/tests/loss.rs
index d772f176..ccfc029f 100644
--- a/candle-nn/tests/loss.rs
+++ b/candle-nn/tests/loss.rs
@@ -39,3 +39,50 @@ fn nll_and_cross_entropy() -> Result<()> {
    assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
    Ok(())
}
+
+/* Equivalent python code:
+import torch
+import torch.nn.functional as F
+
+inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
+                    [ 0.0419,  0.0763, -1.0457, -1.6692],
+                    [-1.0494,  0.8111,  1.5723,  1.2315],
+                    [ 1.3081,  0.6641,  1.1802, -0.2547],
+                    [ 0.5292,  0.7636,  0.3692, -0.8318]])
+
+target = torch.Tensor([[0., 1., 0., 0.],
+                       [0., 1., 0., 0.],
+                       [0., 0., 0., 1.],
+                       [1., 0., 0., 0.],
+                       [0., 0., 1., 0.]])
+
+print(F.binary_cross_entropy_with_logits(inp, target))
+*/
+#[test]
+fn binary_cross_entropy_with_logit() -> Result<()> {
+    let cpu = Device::Cpu;
+
+    let inp = [
+        [2.3611f32, -0.8813, -0.5006, -0.2178],
+        [0.0419, 0.0763, -1.0457, -1.6692],
+        [-1.0494, 0.8111, 1.5723, 1.2315],
+        [1.3081, 0.6641, 1.1802, -0.2547],
+        [0.5292, 0.7636, 0.3692, -0.8318],
+    ];
+
+    let target = [
+        [0.0f32, 1., 0., 0.],
+        [0., 1., 0., 0.],
+        [0., 0., 0., 1.],
+        [1., 0., 0., 0.],
+        [0., 0., 1., 0.],
+    ];
+
+    let inp = Tensor::new(&inp, &cpu)?;
+    let target = Tensor::new(&target, &cpu)?;
+
+    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
+
+    assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
+    Ok(())
+}
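
For reference, a minimal sketch of calling the new loss from downstream code (the candle_core / candle_nn crate paths and the exact tensor values here are illustrative, not taken from the commit):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Raw logits paired with {0., 1.} float targets of the same shape.
    let inp = Tensor::new(&[[2.3611f32, -0.8813], [0.0419, 0.0763]], &dev)?;
    let target = Tensor::new(&[[0.0f32, 1.], [0., 1.]], &dev)?;
    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
    println!("bce with logits: {loss}");
    Ok(())
}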