Diffstat (limited to 'candle-examples/examples/llama/model.rs')
-rw-r--r--  candle-examples/examples/llama/model.rs  |  10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 6ee4a585..0da3697f 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -1,8 +1,8 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
+use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
-use serde::Deserialize;
use super::MAX_SEQ_LEN;
@@ -17,9 +17,9 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
}
-impl LlamaConfig{
- pub fn into_config(&self, use_flash_attn: bool) -> Config{
- Config{
+impl LlamaConfig {
+ pub fn into_config(self, use_flash_attn: bool) -> Config {
+ Config {
hidden_size: self.hidden_size,
intermediate_size: self.intermediate_size,
vocab_size: self.vocab_size,
@@ -27,7 +27,7 @@ impl LlamaConfig{
num_attention_heads: self.num_attention_heads,
num_key_value_heads: self.num_key_value_heads,
rms_norm_eps: self.rms_norm_eps,
- use_flash_attn
+ use_flash_attn,
}
}
}
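
Note: besides reordering the `use serde::Deserialize;` import and applying rustfmt-style spacing, the second hunk changes `into_config` to take `self` by value rather than `&self`, so the deserialized `LlamaConfig` is consumed when it is converted into the runtime `Config`. Below is a minimal standalone sketch of that pattern, not the actual candle code: the structs are stand-ins carrying only the fields visible in this diff (the real types have more), and it assumes `serde` (with the `derive` feature) and `serde_json` as dependencies.

// Minimal sketch of the pattern adopted in this diff: a serde-deserialized raw
// config whose `into_config` consumes `self`. Stand-in types, not candle's.
use serde::Deserialize;

#[derive(Deserialize)]
struct LlamaConfig {
    hidden_size: usize,
    intermediate_size: usize,
    vocab_size: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    rms_norm_eps: f64,
}

#[derive(Debug)]
struct Config {
    hidden_size: usize,
    intermediate_size: usize,
    vocab_size: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    rms_norm_eps: f64,
    use_flash_attn: bool,
}

impl LlamaConfig {
    // Taking `self` by value signals that the raw config is no longer needed
    // once converted, and moves the fields instead of copying from a borrow.
    fn into_config(self, use_flash_attn: bool) -> Config {
        Config {
            hidden_size: self.hidden_size,
            intermediate_size: self.intermediate_size,
            vocab_size: self.vocab_size,
            num_attention_heads: self.num_attention_heads,
            num_key_value_heads: self.num_key_value_heads,
            rms_norm_eps: self.rms_norm_eps,
            use_flash_attn,
        }
    }
}

fn main() -> Result<(), serde_json::Error> {
    // A config.json-like payload; the example binary would read this from disk.
    let raw = r#"{
        "hidden_size": 4096,
        "intermediate_size": 11008,
        "vocab_size": 32000,
        "num_attention_heads": 32,
        "num_key_value_heads": 32,
        "rms_norm_eps": 1e-5
    }"#;
    let llama_config: LlamaConfig = serde_json::from_str(raw)?;
    let config = llama_config.into_config(false);
    println!("{config:?}");
    Ok(())
}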