diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-28 12:13:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-28 12:13:41 +0100 |
commit | 68eab38de6e5cabf17159a5dcf45ec703fbea441 (patch) | |
tree | a6041c0f578667e319f0b6849fa6a6116b734358 /candle-examples/examples/bigcode/model.rs | |
parent | 54ccf944727483615e0bce1fa07499522ab9ca9c (diff) | |
download | candle-68eab38de6e5cabf17159a5dcf45ec703fbea441.tar.gz candle-68eab38de6e5cabf17159a5dcf45ec703fbea441.tar.bz2 candle-68eab38de6e5cabf17159a5dcf45ec703fbea441.zip |
Cuda fix for starcoder. (#266)
* Cuda fix for starcoder.
* Nicer output.
Diffstat (limited to 'candle-examples/examples/bigcode/model.rs')
-rw-r--r-- | candle-examples/examples/bigcode/model.rs | 6 |
1 files changed, 3 insertions, 3 deletions
```diff
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 3b8033bb..3f68a5be 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -22,11 +22,11 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
     Ok(LayerNorm::new(weight, bias, eps))
 }
 
-fn make_causal_mask(t: usize) -> Result<Tensor> {
+fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..t)
         .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
         .collect();
-    let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+    let mask = Tensor::from_slice(&mask, (t, t), device)?;
     Ok(mask)
 }
 
@@ -327,7 +327,7 @@ impl GPTBigCode {
             .collect::<Result<Vec<_>>>()?;
         let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
         let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
-        let bias = make_causal_mask(cfg.max_position_embeddings)?;
+        let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
         Ok(Self {
             wte,
             wpe,
```