diff options
Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r-- | candle-examples/examples/bert/main.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 4396326d..11d01a6a 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -386,12 +386,12 @@ impl BertSelfAttention { let attention_scores = query_layer.matmul(&key_layer.t()?)?; let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; - let attention_probs = attention_scores.softmax(attention_scores.rank() - 1)?; + let attention_probs = attention_scores.softmax(candle::D::Minus1)?; let attention_probs = self.dropout.forward(&attention_probs)?; let context_layer = attention_probs.matmul(&value_layer)?; let context_layer = context_layer.transpose(1, 2)?.contiguous()?; - let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?; + let context_layer = context_layer.flatten_from(candle::D::Minus2)?; Ok(context_layer) } } |