Diffstat (limited to 'candle-examples/examples/bigcode/model.rs')
-rw-r--r--  candle-examples/examples/bigcode/model.rs  12
1 file changed, 1 insertion(+), 11 deletions(-)
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 3f68a5be..12993e2d 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
     Ok(mask)
 }
 
-// TODO: Use a numerically stable implementation by default.
-fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
-    let d = d.to_index(xs.shape(), "log-softmax")?;
-    let max = xs.max_keepdim(d)?;
-    let diff = xs.broadcast_sub(&max)?;
-    let num = diff.exp()?;
-    let den = num.sum_keepdim(d)?;
-    num.broadcast_div(&den)
-}
-
 #[derive(Debug)]
 pub struct Config {
     pub vocab_size: usize,
@@ -192,7 +182,7 @@ impl Attention {
         let mask_value =
             Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
         let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
-        let attn_weights = softmax(&attn_weights, D::Minus1)?;
+        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
         let value = value.contiguous()?;
         let attn_output = if self.multi_query {
             attn_weights
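
For reference, candle_nn::ops::softmax subtracts the per-row maximum before exponentiating, so it matches the numerics of the local helper removed above. Below is a minimal usage sketch, assuming the candle-core and candle-nn crates as dependencies; the input values are illustrative, not from this example.

use candle_core::{D, Device, Result, Tensor};
use candle_nn::ops::softmax;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Logits this large would overflow a naive exp()-based softmax in f32;
    // subtracting the per-row max first keeps exp() within range.
    let logits = Tensor::new(&[[1000f32, 1001., 1002.], [3., 1., 0.2]], &device)?;
    let probs = softmax(&logits, D::Minus1)?;
    println!("{probs}");
    Ok(())
}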