summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/bigcode/model.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-07-28 12:13:41 +0100
committer: GitHub <noreply@github.com> 2023-07-28 12:13:41 +0100
commit: 68eab38de6e5cabf17159a5dcf45ec703fbea441 (patch)
tree: a6041c0f578667e319f0b6849fa6a6116b734358 /candle-examples/examples/bigcode/model.rs
parent: 54ccf944727483615e0bce1fa07499522ab9ca9c (diff)
download: candle-68eab38de6e5cabf17159a5dcf45ec703fbea441.tar.gz
candle-68eab38de6e5cabf17159a5dcf45ec703fbea441.tar.bz2
candle-68eab38de6e5cabf17159a5dcf45ec703fbea441.zip
Cuda fix for starcoder. (#266)
* Cuda fix for starcoder.
* Nicer output.
Diffstat (limited to 'candle-examples/examples/bigcode/model.rs')
-rw-r--r-- candle-examples/examples/bigcode/model.rs 6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 3b8033bb..3f68a5be 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -22,11 +22,11 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new(weight, bias, eps))
}
-fn make_causal_mask(t: usize) -> Result<Tensor> {
+fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
.collect();
- let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+ let mask = Tensor::from_slice(&mask, (t, t), device)?;
Ok(mask)
}
@@ -327,7 +327,7 @@ impl GPTBigCode {
.collect::<Result<Vec<_>>>()?;
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
- let bias = make_causal_mask(cfg.max_position_embeddings)?;
+ let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
Ok(Self {
wte,
wpe,