1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
|
// https://github.com/karpathy/llama2.c
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use candle_transformers::models::llama2_c as model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod training;
use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{IndexOp, Tensor};
use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;
use model::{Cache, Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;
// Options of the `Inference` task (see `Task::Inference`); consumed by `run_inference`.
// NOTE(review): the `///` comments on the fields are kept verbatim because clap
// presumably surfaces them as `--help` text — confirm before rewording them.
#[derive(Parser, Debug, Clone)]
struct InferenceCmd {
// Forwarded unchanged to `LogitsProcessor::new`; `None` when the flag is omitted.
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
// Forwarded unchanged to `LogitsProcessor::new`; `None` when the flag is omitted.
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
// Text the generation starts from: it is echoed to stdout and tokenized to form
// the initial token sequence. Defaults to the empty string.
#[arg(long, default_value = "")]
prompt: String,
// Path to a local weights file. `run_inference` picks the loader from its extension
// (`gguf`, `safetensors`, or anything else). When absent, the file named by
// `which_model` is fetched from the `model_id` hub repository instead.
/// Config file in binary or safetensors format.
#[arg(long)]
config: Option<String>,
// Hub repository that is queried for `which_model` when `config` is not given.
#[arg(long, default_value = "karpathy/tinyllamas")]
model_id: String,
/// The model to be used when getting it from the hub. Possible
/// values are 'stories15M.bin', 'stories42M.bin', see more at:
/// https://huggingface.co/karpathy/tinyllamas/tree/main
#[arg(long, default_value = "stories15M.bin")]
which_model: String,
}
// Options of the `Eval` task (see `Task::Eval`); consumed by `run_eval`.
// NOTE(review): the `///` comments on the fields are kept verbatim because clap
// presumably surfaces them as `--help` text — confirm before rewording them.
#[derive(Parser, Debug, Clone)]
struct EvaluationCmd {
// When set, `run_eval` reads the `data00.bin` shard from this directory. When absent,
// it downloads `TinyStories-valid.txt` from the `roneneldan/TinyStories` hub dataset
// and tokenizes it on the fly.
/// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
/// script from llama2.c https://github.com/karpathy/llama2.c
#[arg(long)]
pretokenized_dir: Option<String>,
// Passed to `Batcher::batch_size` in `run_eval`.
#[arg(long, default_value_t = 32)]
batch_size: usize,
// Path to a local weights file read with `Config::from_reader` followed by
// `TransformerWeights::from_reader`. When absent, the file named by `which_model`
// is fetched from the `model_id` hub repository instead.
/// Config file in binary format.
#[arg(long)]
config: Option<String>,
// Hub repository that is queried for `which_model` when `config` is not given.
#[arg(long, default_value = "karpathy/tinyllamas")]
model_id: String,
/// The model to be used when getting it from the hub. Possible
/// values are 'stories15M.bin', 'stories42M.bin', see more at:
/// https://huggingface.co/karpathy/tinyllamas/tree/main
#[arg(long, default_value = "stories15M.bin")]
which_model: String,
}
// Options of the `Train` task (see `Task::Train`). `main` hands this struct to
// `training::run`, which lives in the `training` module and is not shown in this file.
// NOTE(review): the `///` comments on the fields are kept verbatim because clap
// presumably surfaces them as `--help` text — confirm before rewording them.
#[derive(Parser, Debug, Clone)]
pub struct TrainingCmd {
// Unlike the evaluation command, this directory is mandatory (plain `String`, no default).
/// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
/// script from llama2.c https://github.com/karpathy/llama2.c
#[arg(long)]
pretokenized_dir: String,
// Presumably the number of sequences per training batch — verify against `training::run`.
#[arg(long, default_value_t = 32)]
batch_size: usize,
// Presumably the optimizer learning rate — verify against `training::run`.
#[arg(long, default_value_t = 0.001)]
learning_rate: f64,
}
// The subcommands of the binary. `main` dispatches on the variant; when no subcommand
// is given at all, `main` runs inference with default options.
#[derive(Subcommand, Debug, Clone)]
enum Task {
// Handled by `run_inference`.
Inference(InferenceCmd),
// Handled by `run_eval`.
Eval(EvaluationCmd),
// Handled by `training::run`.
Train(TrainingCmd),
}
// Top-level command line: an optional subcommand plus the options shared by all tasks.
// NOTE(review): the `///` comments on the fields are kept verbatim because clap
// presumably surfaces them as `--help` text — confirm before rewording them.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
// `None` when no subcommand is given; `main` then runs inference with default options.
/// The task to be performed, inference, training or evaluation.
#[command(subcommand)]
task: Option<Task>,
// Passed to `candle_examples::device` to select the compute device.
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
// Path to a local tokenizer file. When absent, `Args::tokenizer` fetches
// `tokenizer.json` from the `hf-internal-testing/llama-tokenizer` hub repository.
/// Tokenizer config file.
#[arg(long)]
tokenizer: Option<String>,
// Used by `run_inference`; the penalty step is skipped entirely when this equals 1.
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
// Number of most recent tokens handed to the repeat-penalty step in `run_inference`.
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
impl Args {
    /// Builds the tokenizer shared by the tasks in this file.
    ///
    /// A path supplied through the `tokenizer` option is used as is. Otherwise
    /// `tokenizer.json` is obtained from the `hf-internal-testing/llama-tokenizer`
    /// hub repository.
    ///
    /// # Errors
    ///
    /// Fails when the hub API cannot be created, the file cannot be fetched, or the
    /// tokenizer file cannot be loaded.
    fn tokenizer(&self) -> Result<Tokenizer> {
        let path = if let Some(local) = &self.tokenizer {
            std::path::PathBuf::from(local)
        } else {
            let hub = hf_hub::api::sync::Api::new()?;
            let repo = hub.model("hf-internal-testing/llama-tokenizer".to_string());
            repo.get("tokenizer.json")?
        };
        Tokenizer::from_file(path).map_err(E::msg)
    }
}
/// Entry point: parses the command line and dispatches to the selected task.
///
/// When no subcommand is given, inference runs with the same values that the
/// `InferenceCmd` options declare as their defaults.
fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    match args.task.as_ref() {
        Some(Task::Train(cmd)) => training::run(cmd, &args)?,
        Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
        Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
        None => {
            let fallback = InferenceCmd {
                temperature: None,
                top_p: None,
                prompt: String::new(),
                config: None,
                model_id: String::from("karpathy/tinyllamas"),
                which_model: String::from("stories15M.bin"),
            };
            run_inference(&fallback, &args)?
        }
    }
    Ok(())
}
/// The model variants the inference loop can drive through a single `forward` call.
enum Model {
    /// Model loaded through `Llama::load`.
    Llama(Llama),
    /// Model loaded through `QLlama::load` (used for `gguf` files).
    QLlama(QLlama),
}

impl Model {
    /// Runs the wrapped model's `forward` on `xs` at position `pos` with `cache`.
    ///
    /// The wrapped model's error is converted into an `anyhow` error so both variants
    /// share one return type.
    fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
        let output = match self {
            Self::Llama(net) => net.forward(xs, pos, cache)?,
            Self::QLlama(net) => net.forward(xs, pos, cache)?,
        };
        Ok(output)
    }
}
/// Evaluates a model on a validation token stream and prints one loss per batch.
///
/// The weights come from `args.config` when given, otherwise `args.which_model` is
/// fetched from the `args.model_id` hub repository. The token stream is either read
/// from the `data00.bin` shard of `args.pretokenized_dir` or produced by tokenizing
/// `TinyStories-valid.txt` from the `roneneldan/TinyStories` hub dataset.
///
/// The stream is cut into windows of `seq_len + 1` tokens taken every `seq_len`
/// tokens; the first `seq_len` tokens of a window are the inputs and the last
/// `seq_len` are the targets. A trailing window that is too short is dropped.
///
/// # Errors
///
/// Propagates any hub, I/O, tokenizer or tensor error.
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
    use std::io::BufRead;

    let weights_path = if let Some(local) = &args.config {
        std::path::PathBuf::from(local)
    } else {
        let hub = hf_hub::api::sync::Api::new()?;
        println!("loading the model weights from {}", args.model_id);
        let repo = hub.model(args.model_id.clone());
        repo.get(&args.which_model)?
    };

    let tokenizer = common_args.tokenizer()?;
    let device = candle_examples::device(common_args.cpu)?;

    // The configuration header and the weights are read back to back from one file.
    let mut reader = std::fs::File::open(weights_path)?;
    let config = Config::from_reader(&mut reader)?;
    let weights = TransformerWeights::from_reader(&mut reader, &config, &device)?;
    let vb = weights.var_builder(&config, &device)?;
    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, config)?;

    let tokens: Vec<u32> = if let Some(dir) = &args.pretokenized_dir {
        // Use shard 0 for the test split, similar to llama2.c
        // https://github.com/karpathy/llama2.c/blob/ce05cc28cf1e3560b873bb21837638a434520a67/tinystories.py#L121
        let shard = std::path::PathBuf::from(dir).join("data00.bin");
        let raw = std::fs::read(shard)?;
        // Tokens are encoded as little-endian u16.
        let mut ids = vec![0u16; raw.len() / 2];
        std::io::Cursor::new(raw).read_u16_into::<LittleEndian>(&mut ids)?;
        ids.into_iter().map(|id| id as u32).collect()
    } else {
        let hub = hf_hub::api::sync::Api::new()?;
        let dataset_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
        println!("loading the evaluation dataset from {}", dataset_id);
        let repo = hub.dataset(dataset_id.to_string());
        let text_path = repo.get("TinyStories-valid.txt")?;
        let text = std::io::BufReader::new(std::fs::File::open(text_path)?);
        let mut ids = Vec::new();
        for line in text.lines() {
            let line = line?.replace("<|endoftext|>", "<s>");
            let encoding = tokenizer.encode(line, false).map_err(E::msg)?;
            ids.extend_from_slice(encoding.get_ids());
        }
        ids
    };
    println!("dataset loaded and encoded: {} tokens", tokens.len());

    let seq_len = model.config.seq_len;
    let windows = (0..tokens.len()).step_by(seq_len).filter_map(|start| {
        // `get` yields `None` for a window that would run past the end of the stream.
        tokens.get(start..start + seq_len + 1).map(|window| {
            let inputs = Tensor::new(&window[..seq_len], &device);
            let targets = Tensor::new(&window[1..], &device);
            match (inputs, targets) {
                (Ok(inputs), Ok(targets)) => Ok((inputs, targets)),
                (Err(err), _) | (_, Err(err)) => Err(err),
            }
        })
    });

    let batches = candle_datasets::Batcher::new_r2(windows).batch_size(args.batch_size);
    for batch in batches {
        let (inp, tgt) = batch?;
        let logits = model.forward(&inp, 0, &mut cache)?;
        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        println!("{}", loss.to_vec0::<f32>()?);
    }
    Ok(())
}
fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let config_path = match &args.config {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
println!("loading the model weights from {}", args.model_id);
let api = api.model(args.model_id.clone());
api.get(&args.which_model)?
}
};
let tokenizer = common_args.tokenizer()?;
let device = candle_examples::device(common_args.cpu)?;
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config, mut cache) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
.dims2()?;
let config = match dim {
64 => Config::tiny_260k(),
288 => Config::tiny_15m(),
512 => Config::tiny_42m(),
768 => Config::tiny_110m(),
_ => anyhow::bail!("no config for dim {dim}"),
};
let freq_cis_real = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_real",
)?
.dequantize(&device)?;
let freq_cis_imag = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag",
)?
.dequantize(&device)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
[
("freq_cis_real".to_string(), freq_cis_real),
("freq_cis_imag".to_string(), freq_cis_imag),
]
.into_iter()
.collect(),
candle::DType::F32,
&device,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, config.clone())?);
(model, config, cache)
} else if is_safetensors {
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config, cache)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config, cache)
};
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0;
print!("{}", args.prompt);
let mut tokens = tokenizer
.encode(args.prompt.clone(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let start_gen = std::time::Instant::now();
for index in 0.. {
if tokens.len() >= config.seq_len {
break;
}
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos, &mut cache)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits
} else {
let start_at = tokens.len().saturating_sub(common_args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
common_args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
tokens.len(),
tokens.len() as f64 / dt.as_secs_f64(),
);
Ok(())
}
|