-rw-r--r--  candle-core/src/error.rs                                3
-rw-r--r--  candle-core/src/npy.rs                                 10
-rw-r--r--  candle-core/src/utils.rs                                7
-rw-r--r--  candle-examples/examples/llama/convert_checkpoint.py  251
-rw-r--r--  candle-examples/examples/llama/main.rs                 10
-rw-r--r--  candle-nn/src/var_builder.rs                           83
6 files changed, 269 insertions, 95 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index caad3e1f..27fd11bb 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -139,6 +139,9 @@ pub enum Error {
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
+
+ #[error("cannot find tensor {path}")]
+ CannotFindTensor { path: String },
}
pub type Result<T> = std::result::Result<T, Error>;
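
The new `CannotFindTensor` variant is what `VarBuilder` in candle-nn returns further down when a name is missing from an in-memory tensor map. A minimal sketch of the same pattern, using a hypothetical `lookup` helper over a plain `HashMap`:

```rust
use candle::{Error, Result, Tensor};
use std::collections::HashMap;

// Hypothetical helper (not part of this commit): turn a missing map entry
// into the new structured error instead of unwrapping.
fn lookup(tensors: &HashMap<String, Tensor>, path: &str) -> Result<Tensor> {
    tensors
        .get(path)
        .cloned()
        .ok_or_else(|| Error::CannotFindTensor {
            path: path.to_string(),
        })
}
```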
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index c0608519..7cf6d381 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -1,10 +1,10 @@
-//! Numpy support for literals.
+//! Numpy support for tensors.
//!
//! The spec for the npy format can be found in
//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
-//! The functions from this module can be used to read literals from npy/npz files
-//! or write literals to these files. A npy file contains a single literal (unnamed)
-//! whereas a npz file can contain multiple named literals. npz files are also compressed.
+//! The functions from this module can be used to read tensors from npy/npz files
+//! or write tensors to these files. A npy file contains a single tensor (unnamed)
+//! whereas a npz file can contain multiple named tensors. npz files are also compressed.
//!
//! These two formats are easy to use in Python using the numpy library.
//!
@@ -232,7 +232,7 @@ impl Tensor {
}
}
- /// Reads a npy file and return the stored multi-dimensional array as a literal.
+ /// Reads a npy file and return the stored multi-dimensional array as a tensor.
pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
let mut reader = File::open(path.as_ref())?;
let header = read_header(&mut reader)?;
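
As the reworded module docs put it, a npy file holds a single unnamed tensor while a npz archive holds several named tensors. A short usage sketch, assuming `weights.npy` and `weights.npz` exist on disk:

```rust
use candle::{Result, Tensor};

fn main() -> Result<()> {
    // A .npy file stores a single, unnamed tensor.
    let t = Tensor::read_npy("weights.npy")?;
    println!("npy tensor: {:?}", t.shape());

    // A .npz archive stores several named tensors.
    for (name, tensor) in Tensor::read_npz("weights.npz")? {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}
```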
diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs
index 4b1e941b..b5621e56 100644
--- a/candle-core/src/utils.rs
+++ b/candle-core/src/utils.rs
@@ -10,3 +10,10 @@ pub fn get_num_threads() -> usize {
Some(_) | None => num_cpus::get(),
}
}
+
+pub fn has_mkl() -> bool {
+ #[cfg(feature = "mkl")]
+ return true;
+ #[cfg(not(feature = "mkl"))]
+ return false;
+}
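
`has_mkl` simply reports whether the crate was built with the `mkl` feature. A tiny sketch of how a caller might surface this at runtime, assuming the `utils` module is exported publicly from the `candle` crate as `get_num_threads` suggests:

```rust
fn main() {
    // Both helpers live in candle-core's utils module.
    println!("mkl enabled: {}", candle::utils::has_mkl());
    println!("threads:     {}", candle::utils::get_num_threads());
}
```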
diff --git a/candle-examples/examples/llama/convert_checkpoint.py b/candle-examples/examples/llama/convert_checkpoint.py
index 245c167c..1b44a04a 100644
--- a/candle-examples/examples/llama/convert_checkpoint.py
+++ b/candle-examples/examples/llama/convert_checkpoint.py
@@ -1,68 +1,199 @@
-# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
-import sys
+# Adapted from:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+import argparse
+import gc
+import json
+import math
+import os
+import shutil
+import warnings
+
import torch
import numpy as np
-from typing import Dict
-from pathlib import Path
-
-def tr(v):
- return np.ascontiguousarray(np.transpose(v))
-
-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
- print("start conv")
-
- def get_and_remove(key, transpose=False):
- v = state_dict[key].to(dtype).numpy()
- if transpose:
- v = tr(v)
- del state_dict[key]
- return v
-
- converted = {}
- converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
- converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
- converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
-
- for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
- print(layer_idx)
-
- # attention
- # the wq, wk, wv from the FB model are stacked in our model as c_attn
- converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
- (
- get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
- get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
- get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
+
+"""
+Sample usage:
+
+```
+python candle-examples/examples/llama/convert_checkpoint.py \
+ --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
+```
+"""
+
+INTERMEDIATE_SIZE_MAP = {
+ "7B": 11008,
+ "13B": 13824,
+ "30B": 17920,
+ "65B": 22016,
+}
+NUM_SHARDS = {
+ "7B": 1,
+ "13B": 2,
+ "30B": 4,
+ "65B": 8,
+}
+
+
+def compute_intermediate_size(n):
+ return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
+
+
+def read_json(path):
+ with open(path, "r") as f:
+ return json.load(f)
+
+
+def write_json(text, path):
+ with open(path, "w") as f:
+ json.dump(text, f)
+
+
+def write_model(model_path, input_base_path, model_size):
+ os.makedirs(model_path, exist_ok=True)
+
+ params = read_json(os.path.join(input_base_path, "params.json"))
+ num_shards = NUM_SHARDS[model_size]
+ n_layers = params["n_layers"]
+ n_heads = params["n_heads"]
+ n_heads_per_shard = n_heads // num_shards
+ dim = params["dim"]
+ dims_per_head = dim // n_heads
+ base = 10000.0
+ inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+
+ # permute for sliced rotary
+ def permute(w):
+ return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+ # Load weights
+ if model_size == "7B":
+ # Not sharded
+ # (The sharded implementation would also work, but this is simpler.)
+ loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
+ else:
+ # Sharded
+ loaded = [
+ torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+ for i in range(num_shards)
+ ]
+ param_count = 0
+ all_dicts = {}
+ for layer_i in range(n_layers):
+ if model_size == "7B":
+ # Unsharded
+ state_dict = {
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
+ loaded[f"layers.{layer_i}.attention.wq.weight"]
+ ),
+ f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
+ loaded[f"layers.{layer_i}.attention.wk.weight"]
+ ),
+ f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
+ f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
+ }
+ else:
+ # Sharded
+ # Note that attention.w{q,k,v,o}, feed_forward.w{1,2,3}, attention_norm.weight and ffn_norm.weight share
+ # the same storage object, so saving attention_norm and ffn_norm would also save those other weights, which
+ # is redundant as they will be stitched together from multiple shards. To avoid that, the norms are cloned.
+
+ state_dict = {
+ f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
+ f"layers.{layer_i}.attention_norm.weight"
+ ].clone(),
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
+ f"layers.{layer_i}.ffn_norm.weight"
+ ].clone(),
+ }
+ state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
+ torch.cat(
+ [
+ loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
+ for i in range(num_shards)
+ ],
+ dim=0,
+ ).reshape(dim, dim)
+ )
+ state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
+ torch.cat(
+ [
+ loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
+ for i in range(num_shards)
+ ],
+ dim=0,
+ ).reshape(dim, dim)
)
- ))
- converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
- f"layers.{layer_idx}.attention.wo.weight"
- ))
- # mlp
- converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
- f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
+ state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
+ [
+ loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
+ for i in range(num_shards)
+ ],
+ dim=0,
+ ).reshape(dim, dim)
+
+ state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+ [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
)
- converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
- f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
+ state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+ [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
)
- converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
- f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
+ state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+ [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
)
- # rms norm
- converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
- converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
- return converted
-
-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
- dt = getattr(torch, dtype, None)
- if not isinstance(dt, torch.dtype):
- raise ValueError(f"{dtype} is not a valid dtype.")
- checkpoint = torch.load(llama_ckpt, map_location="cpu")
- converted = convert_state_dict(checkpoint, dtype=dt)
- del checkpoint
- np.savez(output_npz, **converted)
+ state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+ [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
+ )
+
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
+ all_dicts |= state_dict
+
+ if model_size == "7B":
+ # Unsharded
+ state_dict = {
+ "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
+ "model.norm.weight": loaded["norm.weight"],
+ "lm_head.weight": loaded["output.weight"],
+ }
+ else:
+ state_dict = {
+ "model.norm.weight": loaded[0]["norm.weight"],
+ "model.embed_tokens.weight": torch.cat(
+ [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
+ ),
+ "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
+ }
+ all_dicts |= state_dict
+ all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
+ np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input_dir",
+ help="Location of LLaMA weights, which contains tokenizer.model and model folders",
+ )
+ parser.add_argument(
+ "--model_size",
+ choices=["7B", "13B", "30B", "65B"],
+ )
+ parser.add_argument(
+ "--output_dir",
+ help="Location to write HF model and tokenizer",
+ )
+ args = parser.parse_args()
+ write_model(
+ model_path=args.output_dir,
+ input_base_path=os.path.join(args.input_dir, args.model_size),
+ model_size=args.model_size,
+ )
+
if __name__ == "__main__":
- if len(sys.argv) != 2:
- raise ValueError(f"usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
- convert_weights(sys.argv[1])
+ main()
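
`compute_intermediate_size` and `INTERMEDIATE_SIZE_MAP` are related: each map entry is the formula (ceiling of `dim * 8 / 3`, rounded up to a multiple of 256) evaluated at that model's hidden size. A quick sketch of the same integer arithmetic in Rust, checked against the 7B entry (hidden size 4096):

```rust
// Integer form of the Python helper: ceil(n * 8 / 3), then round up to a
// multiple of 256.
fn compute_intermediate_size(n: usize) -> usize {
    ((n * 8 + 2) / 3 + 255) / 256 * 256
}

fn main() {
    // 7B: ceil(4096 * 8 / 3) = 10923; 10923 + 255 = 11178;
    // 11178 / 256 = 43; 43 * 256 = 11008, the "7B" entry in INTERMEDIATE_SIZE_MAP.
    assert_eq!(compute_intermediate_size(4096), 11008);
}
```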
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 75cea7ff..6ac4458e 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -144,8 +144,14 @@ fn main() -> Result<()> {
let config = Config::config_7b();
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
let (llama, tokenizer_filename) = match args.npy {
- Some(_) => {
- todo!("fix numpy handling if we continue supporting it")
+ Some(filename) => {
+ let tensors = Tensor::read_npz(filename)?
+ .into_iter()
+ .map(|(n, t)| Ok((n, t.to_dtype(DTYPE)?)))
+ .collect::<Result<std::collections::HashMap<String, Tensor>>>()?;
+ let vb = VarBuilder::from_tensors(tensors, DTYPE, &device);
+ let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
+ (Llama::load(vb, &cache, &config)?, tokenizer)
}
None => {
let api = Api::new()?;
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index d71b5822..6d79bddd 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,15 +1,20 @@
-use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
use std::collections::HashMap;
use std::sync::Arc;
-struct SafeTensorWithRouting<'a> {
- routing: HashMap<String, usize>,
- safetensors: Vec<SafeTensors<'a>>,
+// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
+// generics.
+enum Tensors<'a> {
+ SafeTensorWithRouting {
+ routing: HashMap<String, usize>,
+ safetensors: Vec<SafeTensors<'a>>,
+ },
+ TensorMap(HashMap<String, Tensor>),
+ Zeros,
}
struct TensorData<'a> {
- // TODO: Make this part generic, probably via some Box<dyn> to avoid too much generics.
- safetensors: Option<SafeTensorWithRouting<'a>>,
+ tensors: Tensors<'a>,
pub dtype: DType,
pub device: Device,
}
@@ -22,12 +27,12 @@ impl<'a> TensorData<'a> {
routing.insert(k.to_string(), index);
}
}
- let safetensors = SafeTensorWithRouting {
+ let tensors = Tensors::SafeTensorWithRouting {
routing,
safetensors,
};
Self {
- safetensors: Some(safetensors),
+ tensors,
device: device.clone(),
dtype,
}
@@ -35,7 +40,15 @@ impl<'a> TensorData<'a> {
fn zeros(dtype: DType, device: &Device) -> Self {
Self {
- safetensors: None,
+ tensors: Tensors::Zeros,
+ device: device.clone(),
+ dtype,
+ }
+ }
+
+ fn from_tensors(tensors: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
+ Self {
+ tensors: Tensors::TensorMap(tensors),
device: device.clone(),
dtype,
}
@@ -67,6 +80,14 @@ impl<'a> VarBuilder<'a> {
}
}
+ pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
+ let data = TensorData::from_tensors(ts, dtype, device);
+ Self {
+ data: Arc::new(data),
+ path: vec![],
+ }
+ }
+
pub fn push_prefix(&self, s: &str) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
@@ -94,31 +115,37 @@ impl<'a> VarBuilder<'a> {
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
let data = self.data.as_ref();
let s: Shape = s.into();
- match &self.data.safetensors {
- None => Tensor::zeros(s, data.dtype, &data.device),
- Some(SafeTensorWithRouting {
+ let path = if self.path.is_empty() {
+ tensor_name.to_string()
+ } else {
+ [&self.path.join("."), tensor_name].join(".")
+ };
+ let tensor = match &self.data.tensors {
+ Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
+ Tensors::TensorMap(ts) => ts
+ .get(&path)
+ .ok_or_else(|| Error::CannotFindTensor {
+ path: path.to_string(),
+ })?
+ .clone(),
+ Tensors::SafeTensorWithRouting {
routing,
safetensors,
- }) => {
- let path = if self.path.is_empty() {
- tensor_name.to_string()
- } else {
- [&self.path.join("."), tensor_name].join(".")
- };
+ } => {
// Unwrap or 0 just to let the proper error flow.
let index = routing.get(&path).unwrap_or(&0);
- let tensor = safetensors[*index]
+ safetensors[*index]
.tensor(&path, &data.device)?
- .to_dtype(data.dtype)?;
- if *tensor.shape() != s {
- Err(candle::Error::UnexpectedShape {
- msg: format!("shape mismatch for {path}"),
- expected: s,
- got: tensor.shape().clone(),
- })?
- }
- Ok(tensor)
+ .to_dtype(data.dtype)?
}
+ };
+ if tensor.shape() != &s {
+ Err(candle::Error::UnexpectedShape {
+ msg: format!("shape mismatch for {path}"),
+ expected: s,
+ got: tensor.shape().clone(),
+ })?
}
+ Ok(tensor)
}
}
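
Putting the pieces together, the new `Tensors::TensorMap` backend lets a model be assembled from an in-memory map: the prefix path and tensor name are joined with '.', a missing key becomes `CannotFindTensor`, and the shape check now runs for every backend. A small end-to-end sketch, assuming the builder is re-exported as `candle_nn::VarBuilder`:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use std::collections::HashMap;

fn main() -> Result<()> {
    let device = Device::Cpu;
    let mut tensors = HashMap::new();
    tensors.insert(
        "linear.weight".to_string(),
        Tensor::zeros((4, 4), DType::F32, &device)?,
    );
    let vb = VarBuilder::from_tensors(tensors, DType::F32, &device);

    // Prefixes are joined with '.', so this looks up the key "linear.weight".
    let w = vb.push_prefix("linear").get((4, 4), "weight")?;
    assert_eq!(w.shape().dims(), &[4, 4]);

    // A name that is not in the map surfaces as Error::CannotFindTensor.
    assert!(vb.get((4, 4), "does_not_exist").is_err());
    Ok(())
}
```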