diff options
Diffstat (limited to 'candle-examples/examples/llama/model.rs')
-rw-r--r-- | candle-examples/examples/llama/model.rs | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs index 6ee4a585..0da3697f 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-examples/examples/llama/model.rs @@ -1,8 +1,8 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; +use serde::Deserialize; use std::collections::HashMap; use std::sync::{Arc, Mutex}; -use serde::Deserialize; use super::MAX_SEQ_LEN; @@ -17,9 +17,9 @@ pub struct LlamaConfig { pub rms_norm_eps: f64, } -impl LlamaConfig{ - pub fn into_config(&self, use_flash_attn: bool) -> Config{ - Config{ +impl LlamaConfig { + pub fn into_config(self, use_flash_attn: bool) -> Config { + Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, @@ -27,7 +27,7 @@ impl LlamaConfig{ num_attention_heads: self.num_attention_heads, num_key_value_heads: self.num_key_value_heads, rms_norm_eps: self.rms_norm_eps, - use_flash_attn + use_flash_attn, } } } |