summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/t5_model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/musicgen/t5_model.rs')
-rw-r--r--candle-examples/examples/musicgen/t5_model.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 0444f360..2119cf9b 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -98,7 +98,7 @@ impl T5LayerNorm {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
let xs2_f32 = (&xs_f32 * &xs_f32)?;
- let sum_xs2_f32 = xs2_f32.sum(&[xs.rank() - 1])?;
+ let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?;
let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;