-rw-r--r-- | candle-core/examples/llama/main.rs | 18 |
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 8feb7fb0..e936d6b3 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -152,7 +152,7 @@ impl Linear {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?;
+        let x = x.matmul(&self.weight.t()?)?;
         Ok(x)
     }
 }
@@ -167,8 +167,9 @@ impl RmsNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = x.to_dtype(DType::F32)?;
         let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
@@ -176,7 +177,9 @@ impl RmsNorm {
             .scale
             .to_dtype(DType::F32)?
             .broadcast_as((seq_len, size))?;
-        Ok((scale * x_normed)?)
+        let x = (scale * x_normed)?;
+        let x = x.to_dtype(DType::F16)?;
+        Ok(x)
     }
 }
 
@@ -285,6 +288,7 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let (t, c) = x.shape().r2()?;
         let qkv = self.c_attn.forward(x)?;
+        let qkv = qkv.to_dtype(DType::F32)?;
         let n_embd = c;
         let q = qkv.narrow(1, 0, n_embd)?;
         let k = qkv.narrow(1, n_embd, n_embd)?;
@@ -303,6 +307,7 @@ impl CausalSelfAttention {
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+        let y = y.to_dtype(DType::F16)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
@@ -352,14 +357,14 @@ impl Llama {
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         // TODO: Support for mini-batches? (i.e. r2)
         let t = x.shape().r1()?;
-        let x = self.wte.forward(x)?;
-        let mut x = x.to_dtype(DType::F32)?;
+        let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, freqs_cis)?;
         }
         let x = self.ln_f.forward(&x)?;
         let x = x.narrow(0, t - 1, 1)?;
         let logits = self.lm_head.forward(&x)?;
+        let logits = logits.to_dtype(DType::F32)?;
         let (b, vocab_size) = logits.shape().r2()?;
         assert_eq!(b, 1);
         Ok(logits.reshape(vocab_size)?)
@@ -420,7 +425,6 @@ async fn main() -> Result<()> {
     } else {
         Device::new_cuda(0)?
     };
-    let api = Api::new()?;
     let config = Config::config_7b();
     let cache = Cache::new(&device);
     let start = std::time::Instant::now();
@@ -431,7 +435,9 @@ async fn main() -> Result<()> {
             std::path::Path::new("llama-tokenizer.json").to_path_buf(),
         )
     } else {
+        let api = Api::new()?;
        let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+        println!("building the model");
         let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
         let mut filenames = vec![];
         for rfilename in [
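
The pattern running through this change: weights and most activations stay in F16, the numerically sensitive steps (the RmsNorm reduction and the attention matmuls) are upcast to F32, and the result is converted back to F16 before the next F16 matmul, with the final logits upcast to F32 once at the end. Below is a minimal, framework-free sketch of that RmsNorm pattern in plain Rust; the `half` crate, the function name, the test values, and the epsilon handling are illustrative assumptions, not part of this commit or of candle's API.

```rust
// Minimal sketch (not candle code) of the mixed-precision RMSNorm pattern in
// this diff: upcast f16 activations to f32, normalize, then downcast to f16.
// Assumes the `half` crate (e.g. `half = "2"` in Cargo.toml) for the f16 type.
use half::f16;

fn rms_norm_f16(x: &[f16], scale: &[f16], eps: f32) -> Vec<f16> {
    // Upcast the row to f32 so the sum of squares is accumulated in full precision.
    let xf: Vec<f32> = x.iter().map(|v| v.to_f32()).collect();
    let mean_sq = xf.iter().map(|v| v * v).sum::<f32>() / xf.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Apply the learned scale, then downcast the normalized values back to f16.
    xf.iter()
        .zip(scale.iter())
        .map(|(v, s)| f16::from_f32(v * inv_rms * s.to_f32()))
        .collect()
}

fn main() {
    let x: Vec<f16> = [1.0f32, 2.0, 3.0, 4.0].iter().map(|&v| f16::from_f32(v)).collect();
    let scale = vec![f16::from_f32(1.0); 4];
    let y = rms_norm_f16(&x, &scale, 1e-5);
    println!("{:?}", y.iter().map(|v| v.to_f32()).collect::<Vec<_>>());
}
```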