summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-28 08:30:35 +0200
committerGitHub <noreply@github.com>2024-04-28 08:30:35 +0200
commite5c8b88f90763073fc927ee232bda30fcbc05595 (patch)
tree4abf7f650b791ce7b64c2ebe7a9f88703771bf09 /candle-nn
parent805f3be8e1f28135b015ddebbe6c8ef3a8c53d13 (diff)
downloadcandle-e5c8b88f90763073fc927ee232bda30fcbc05595.tar.gz
candle-e5c8b88f90763073fc927ee232bda30fcbc05595.tar.bz2
candle-e5c8b88f90763073fc927ee232bda30fcbc05595.zip
Apply the cast before the scaling. (#2135)
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/ops.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 1dac8c3b..7fc26c3f 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -70,7 +70,7 @@ pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
let scale = 1.0 / (1.0 - drop_p as f64);
let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
- let mask = (rand.ge(&drop_p)? * scale)?.to_dtype(xs.dtype())?;
+ let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
xs * mask
}