Diffstat (limited to 'candle-nn/src/loss.rs')
 candle-nn/src/loss.rs | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+), 0 deletions(-)
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index 72451f83..fb1e11f4 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
(inp - target)?.sqr()?.mean_all()
}
+
+/// The binary cross-entropy with logit loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+/// of categories. This is expected to be raw logits.
+/// * [target]: The ground truth labels as a tensor of dimension `N, C` where `N` is the batch size
+/// and `C` the number of categories, with values `0.` or `1.` and the same dtype as `inp`.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
+pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+ let inp = crate::ops::sigmoid(inp)?;
+
+ let left_side = target * inp.log()?;
+ let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
+
+ let loss = left_side? + right_side?;
+ let loss = loss?.neg()?.mean_all()?;
+
+ Ok(loss)
+}
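
For reference, a minimal sketch of how the new helper might be called (assuming the usual `candle_core`/`candle_nn` crate names; the tensor values here are made up for illustration):

use candle_core::{Device, Result, Tensor};
use candle_nn::loss::binary_cross_entropy_with_logit;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw logits for a batch of 2 samples over 3 categories (hypothetical values).
    let inp = Tensor::new(&[[2.3f32, -1.2, 0.7], [-0.4, 1.9, -2.1]], &device)?;
    // Ground-truth labels with the same shape and dtype, each entry 0. or 1.
    let target = Tensor::new(&[[1.0f32, 0.0, 1.0], [0.0, 1.0, 0.0]], &device)?;
    // Scalar loss: the batch-averaged binary cross-entropy of sigmoid(inp) vs target.
    let loss = binary_cross_entropy_with_logit(&inp, &target)?;
    println!("loss: {}", loss.to_scalar::<f32>()?);
    Ok(())
}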