summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/t5_model.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-13 21:32:32 +0100
committerGitHub <noreply@github.com>2023-07-13 21:32:32 +0100
commit2bfa791336b320b96d392aba83cbd4cee87173e3 (patch)
treea3127719a64cf5cfbf38f5f8be859afd2dc6118e /candle-examples/examples/musicgen/t5_model.rs
parent57be3638d8c10304629f6859d183fb192858f3a3 (diff)
downloadcandle-2bfa791336b320b96d392aba83cbd4cee87173e3.tar.gz
candle-2bfa791336b320b96d392aba83cbd4cee87173e3.tar.bz2
candle-2bfa791336b320b96d392aba83cbd4cee87173e3.zip
Use the same default as pytorch for sum. (#164)
Diffstat (limited to 'candle-examples/examples/musicgen/t5_model.rs')
-rw-r--r--candle-examples/examples/musicgen/t5_model.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 0444f360..2119cf9b 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -98,7 +98,7 @@ impl T5LayerNorm {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
let xs2_f32 = (&xs_f32 * &xs_f32)?;
- let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
+ let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;