summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama/main.rs')
-rw-r--r--candle-examples/examples/llama/main.rs17
1 files changed, 17 insertions, 0 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index f3cf17bc..b2c4e55a 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -9,6 +9,9 @@
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -111,6 +114,10 @@ struct Args {
#[arg(long)]
use_f32: bool,
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
#[arg(long)]
model_id: Option<String>,
@@ -123,8 +130,18 @@ struct Args {
fn main() -> Result<()> {
use tokenizers::Tokenizer;
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
let args = Args::parse();
+ let _guard = if args.tracing {
+ println!("tracing...");
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
let device = candle_examples::device(args.cpu)?;
let config = if args.v1 {