summaryrefslogtreecommitdiff
path: root/candle-nn/src/linear.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/linear.rs')
-rw-r--r--candle-nn/src/linear.rs5
1 files changed, 3 insertions, 2 deletions
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 7028f68c..de335964 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -41,8 +41,9 @@ impl Linear {
impl super::Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
- let w = match x.dims() {
- &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+ let w = match *x.dims() {
+ [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
+ [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;