summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c/main.rs
blob: 418218b6888ab6c7a7c1281795fc1705a5e1a6b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
// https://github.com/karpathy/llama2.c

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod model;
mod training;
mod weights;
use clap::{Parser, Subcommand};

use anyhow::{Error as E, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{IndexOp, Tensor};
use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;

use model::{Config, Llama};
use weights::TransformerWeights;

#[derive(Parser, Debug, Clone)]
struct InferenceCmd {
    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    #[arg(long, default_value = "")]
    prompt: String,

    /// Config file in binary or safetensors format.
    #[arg(long)]
    config: Option<String>,

    #[arg(long, default_value = "karpathy/tinyllamas")]
    model_id: String,

    /// The model to be used when getting it from the hub. Possible
    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
    /// https://huggingface.co/karpathy/tinyllamas/tree/main
    #[arg(long, default_value = "stories15M.bin")]
    which_model: String,
}

#[derive(Parser, Debug, Clone)]
struct EvaluationCmd {
    /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
    /// script from llama2.c https://github.com/karpathy/llama2.c
    #[arg(long)]
    pretokenized_dir: Option<String>,

    #[arg(long, default_value_t = 32)]
    batch_size: usize,

    /// Config file in binary format.
    #[arg(long)]
    config: Option<String>,

    #[arg(long, default_value = "karpathy/tinyllamas")]
    model_id: String,

    /// The model to be used when getting it from the hub. Possible
    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
    /// https://huggingface.co/karpathy/tinyllamas/tree/main
    #[arg(long, default_value = "stories15M.bin")]
    which_model: String,
}

#[derive(Parser, Debug, Clone)]
pub struct TrainingCmd {
    /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
    /// script from llama2.c https://github.com/karpathy/llama2.c
    #[arg(long)]
    pretokenized_dir: String,

    #[arg(long, default_value_t = 32)]
    batch_size: usize,

    #[arg(long, default_value_t = 0.001)]
    learning_rate: f64,
}

#[derive(Subcommand, Debug, Clone)]
enum Task {
    Inference(InferenceCmd),
    Eval(EvaluationCmd),
    Train(TrainingCmd),
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// The task to be performed, inference, training or evaluation.
    #[command(subcommand)]
    task: Option<Task>,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Tokenizer config file.
    #[arg(long)]
    tokenizer: Option<String>,
}

impl Args {
    fn tokenizer(&self) -> Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
                api.get("tokenizer.json")?
            }
        };
        Tokenizer::from_file(tokenizer_path).map_err(E::msg)
    }
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    match &args.task {
        None => {
            let cmd = InferenceCmd {
                temperature: None,
                prompt: "".to_string(),
                config: None,
                model_id: "karpathy/tinyllamas".to_string(),
                which_model: "stories15M.bin".to_string(),
            };
            run_inference(&cmd, &args)?
        }
        Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
        Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
        Some(Task::Train(cmd)) => training::run(cmd, &args)?,
    }
    Ok(())
}

fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
    use std::io::BufRead;

    let config_path = match &args.config {
        Some(config) => std::path::PathBuf::from(config),
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            println!("loading the model weights from {}", args.model_id);
            let api = api.model(args.model_id.clone());
            api.get(&args.which_model)?
        }
    };

    let tokenizer = common_args.tokenizer()?;

    let device = candle_examples::device(common_args.cpu)?;
    let mut file = std::fs::File::open(config_path)?;
    let config = Config::from_reader(&mut file)?;
    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
    let vb = weights.var_builder(&config, &device)?;
    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;

    let tokens = match &args.pretokenized_dir {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
            println!("loading the evaluation dataset from {}", model_id);
            let api = api.dataset(model_id.to_string());
            let dataset_path = api.get("TinyStories-valid.txt")?;
            let file = std::fs::File::open(dataset_path)?;
            let file = std::io::BufReader::new(file);
            let mut tokens = vec![];
            for line in file.lines() {
                let line = line?.replace("<|endoftext|>", "<s>");
                let line = tokenizer.encode(line, false).map_err(E::msg)?;
                tokens.push(line.get_ids().to_vec())
            }
            tokens.concat()
        }
        Some(pretokenized_dir) => {
            // Use shard 0 for the test split, similar to llama2.c
            // https://github.com/karpathy/llama2.c/blob/ce05cc28cf1e3560b873bb21837638a434520a67/tinystories.py#L121
            let path = std::path::PathBuf::from(pretokenized_dir).join("data00.bin");
            let bytes = std::fs::read(path)?;
            // Tokens are encoded as u16.
            let mut tokens = vec![0u16; bytes.len() / 2];
            std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
            tokens.into_iter().map(|u| u as u32).collect::<Vec<u32>>()
        }
    };
    println!("dataset loaded and encoded: {} tokens", tokens.len());

    let seq_len = model.config.seq_len;
    let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
        if start_idx + seq_len + 1 > tokens.len() {
            None
        } else {
            let tokens = &tokens[start_idx..start_idx + seq_len + 1];
            let inputs = Tensor::new(&tokens[..seq_len], &device);
            let targets = Tensor::new(&tokens[1..], &device);
            Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
        }
    });
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
    for inp_tgt in batch_iter {
        let (inp, tgt) = inp_tgt?;
        let logits = model.forward(&inp, 0)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        println!("{}", loss.to_vec0::<f32>()?);
    }
    Ok(())
}

fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
    let config_path = match &args.config {
        Some(config) => std::path::PathBuf::from(config),
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            println!("loading the model weights from {}", args.model_id);
            let api = api.model(args.model_id.clone());
            api.get(&args.which_model)?
        }
    };

    let tokenizer = common_args.tokenizer()?;

    let device = candle_examples::device(common_args.cpu)?;

    let is_safetensors = config_path
        .extension()
        .map_or(false, |v| v == "safetensors");
    let (vb, config) = if is_safetensors {
        let config = Config::tiny();
        let tensors = candle::safetensors::load(config_path, &device)?;
        let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
        (vb, config)
    } else {
        let mut file = std::fs::File::open(config_path)?;
        let config = Config::from_reader(&mut file)?;
        println!("{config:?}");
        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
        let vb = weights.var_builder(&config, &device)?;
        (vb, config)
    };
    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;

    println!("starting the inference loop");
    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
    let mut index_pos = 0;

    print!("{}", args.prompt);
    let mut tokens = tokenizer
        .encode(args.prompt.clone(), true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();

    let start_gen = std::time::Instant::now();
    for index in 0.. {
        if tokens.len() >= model.config.seq_len {
            break;
        }
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, index_pos)?;
        let logits = logits.i((0, logits.dim(1)? - 1))?;
        index_pos += ctxt.len();

        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);
        // Extracting the last token as a string is complicated, here we just apply some simple
        // heuristics as it seems to work well enough for this example. See the following for more
        // details:
        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
        if let Some(text) = tokenizer.id_to_token(next_token) {
            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    let dt = start_gen.elapsed();
    println!(
        "\n{} tokens generated ({:.2} token/s)\n",
        tokens.len(),
        tokens.len() as f64 / dt.as_secs_f64(),
    );
    Ok(())
}