Diffstat (limited to 'candle-nn/src')
-rw-r--r--  candle-nn/src/layer_norm.rs  8
-rw-r--r--  candle-nn/src/linear.rs      5
2 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 08e2f628..d2e80a82 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -28,7 +28,7 @@
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
-use candle::{DType, Result, Tensor};
+use candle::{DType, Result, Tensor, D};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
@@ -104,15 +104,15 @@ impl crate::Module for LayerNorm {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
-        let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let x = if self.remove_mean {
-            let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+            let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
-        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
        match &self.bias {
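
The layer_norm.rs hunk above drops the hard-coded rank-3 assumption (`dims3()` and `sum_keepdim(2)`) and takes the normalization statistics over the last dimension via `D::Minus1`, so any input whose trailing dimension is the hidden size is accepted. The following is a minimal usage sketch, not part of this commit: the crate alias `candle` matches this workspace's internal naming (external users typically depend on candle-core), and the shapes are arbitrary.

// Sketch only: crate alias and shapes are assumptions, not part of the commit.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{LayerNorm, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let hidden: usize = 16;
    // Scale and shift parameters over the hidden (last) dimension.
    let weight = Tensor::ones(hidden, DType::F32, &dev)?;
    let bias = Tensor::zeros(hidden, DType::F32, &dev)?;
    let ln = LayerNorm::new(weight, bias, 1e-5);

    // Rank-3 (batch, seq, hidden) input: behaves exactly as before.
    let x3 = Tensor::randn(0f32, 1., (2, 5, hidden), &dev)?;
    assert_eq!(ln.forward(&x3)?.dims(), &[2, 5, hidden]);

    // Rank-2 input: previously rejected by dims3(), now normalized
    // over its last dimension thanks to D::Minus1.
    let x2 = Tensor::randn(0f32, 1., (7, hidden), &dev)?;
    assert_eq!(ln.forward(&x2)?.dims(), &[7, hidden]);
    Ok(())
}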
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 7028f68c..de335964 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -41,8 +41,9 @@ impl Linear {

impl super::Module for Linear {
    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let w = match x.dims() {
-            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+        let w = match *x.dims() {
+            [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
+            [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
            _ => self.weight.t()?,
        };
        let x = x.matmul(&w)?;
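
The linear.rs hunk adds a match arm so a rank-4 input broadcasts the weight over its two leading batch dimensions instead of falling through to the plain `weight.t()` path, which would not satisfy the batched matmul shape check. A hedged sketch of the resulting behaviour, with the same crate-alias assumption and made-up shapes as above:

// Sketch only: crate alias and shapes are assumptions, not part of the commit.
use candle::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (in_dim, out_dim) = (16usize, 32usize);
    // Linear stores its weight as (out_dim, in_dim).
    let weight = Tensor::randn(0f32, 1., (out_dim, in_dim), &dev)?;
    let linear = Linear::new(weight, None);

    // Rank-3 input: handled by the existing [bsize, _, _] arm.
    let x3 = Tensor::randn(0f32, 1., (4, 10, in_dim), &dev)?;
    assert_eq!(linear.forward(&x3)?.dims(), &[4, 10, out_dim]);

    // Rank-4 input: the new [b1, b2, _, _] arm broadcasts and transposes
    // the weight to (b1, b2, in_dim, out_dim) before the batched matmul.
    let x4 = Tensor::randn(0f32, 1., (2, 4, 10, in_dim), &dev)?;
    assert_eq!(linear.forward(&x4)?.dims(), &[2, 4, 10, out_dim]);
    Ok(())
}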