Diffstat (limited to 'candle-nn/src')
-rw-r--r--   candle-nn/src/layer_norm.rs | 8 ++++----
-rw-r--r--   candle-nn/src/linear.rs     | 5 +++--

2 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 08e2f628..d2e80a82 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -28,7 +28,7 @@
 //! ```
 //!
 //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor};
+use candle::{DType, Result, Tensor, D};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct LayerNormConfig {
@@ -104,15 +104,15 @@ impl crate::Module for LayerNorm {
             DType::F16 | DType::BF16 => DType::F32,
             d => d,
         };
-        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+        let hidden_size = x.dim(D::Minus1)?;
         let x = x.to_dtype(internal_dtype)?;
         let x = if self.remove_mean {
-            let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+            let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
         match &self.bias {
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 7028f68c..de335964 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -41,8 +41,9 @@ impl Linear {
 
 impl super::Module for Linear {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let w = match x.dims() {
-            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+        let w = match *x.dims() {
+            [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
+            [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
             _ => self.weight.t()?,
         };
        let x = x.matmul(&w)?;
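
The layer_norm.rs change replaces the hard-coded rank-3 access (x.dims3()? plus sum_keepdim(2)) with the last-dimension index D::Minus1, so LayerNorm::forward accepts inputs of any rank. A minimal sketch of what this enables (not part of the commit; the hidden size, eps value, and shapes are illustrative):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{LayerNorm, Module};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Per-channel scale and shift over the last dimension (hidden size 4).
        let weight = Tensor::ones(4, DType::F32, &device)?;
        let bias = Tensor::zeros(4, DType::F32, &device)?;
        let ln = LayerNorm::new(weight, bias, 1e-5);
        // Rank-2 input (batch, hidden): the old dims3() call would have
        // errored here; normalizing over D::Minus1 handles any rank.
        let x = Tensor::randn(0f32, 1f32, (2, 4), &device)?;
        let y = ln.forward(&x)?;
        assert_eq!(y.dims(), &[2, 4]);
        Ok(())
    }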
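The linear.rs change adds a match arm for 4D inputs, broadcasting the weight over the two leading batch dimensions before the matmul. A usage sketch under the same caveat (shapes chosen for illustration only):

    use candle::{Device, Result, Tensor};
    use candle_nn::{Linear, Module};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Weight is stored as (out_features, in_features).
        let weight = Tensor::randn(0f32, 1f32, (3, 5), &device)?;
        let linear = Linear::new(weight, None);
        // A 4D activation (b1, b2, seq, in_features), e.g. per-head
        // attention output, now matches the new [b1, b2, _, _] arm:
        // broadcast_left((2, 4)) lifts the weight to (2, 4, 3, 5).
        let x = Tensor::randn(0f32, 1f32, (2, 4, 7, 5), &device)?;
        let y = linear.forward(&x)?;
        assert_eq!(y.dims(), &[2, 4, 7, 3]);
        Ok(())
    }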