diff options
Diffstat (limited to 'candle-nn/src/linear.rs')
-rw-r--r-- | candle-nn/src/linear.rs | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs index 7028f68c..de335964 100644 --- a/candle-nn/src/linear.rs +++ b/candle-nn/src/linear.rs @@ -41,8 +41,9 @@ impl Linear { impl super::Module for Linear { fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { - let w = match x.dims() { - &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, + let w = match *x.dims() { + [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, _ => self.weight.t()?, }; let x = x.matmul(&w)?; |