diff options
Diffstat (limited to 'candle-nn/src/loss.rs')
-rw-r--r-- | candle-nn/src/loss.rs | 22 |
1 file changed, 22 insertions, 0 deletions
diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs index 72451f83..fb1e11f4 100644 --- a/candle-nn/src/loss.rs +++ b/candle-nn/src/loss.rs @@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> { pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> { (inp - target)?.sqr()?.mean_all() } + +/// The binary cross-entropy with logit loss. +/// +/// Arguments +/// +/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number +/// of categories. This is expected to raw logits. +/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number +/// of categories. +/// +/// The resulting tensor is a scalar containing the average value over the batch. +pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> { + let inp = crate::ops::sigmoid(inp)?; + + let left_side = target * inp.log()?; + let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?; + + let loss = left_side? + right_side?; + let loss = loss?.neg()?.mean_all()?; + + Ok(loss) +} |