author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-11 19:32:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-11 19:32:10 +0100 |
commit | 37cad858698e519435c916421cc97b4f6b7fe53e (patch) | |
tree | d9a7ddb65a25e53ed684e91b33d4c27eea3dc0d5 /candle-examples/examples/llama/convert_checkpoint.py | |
parent | 760f1d70551a761a815e0a9576c8fecb6bde6020 (diff) | |
download | candle-37cad858698e519435c916421cc97b4f6b7fe53e.tar.gz candle-37cad858698e519435c916421cc97b4f6b7fe53e.tar.bz2 candle-37cad858698e519435c916421cc97b4f6b7fe53e.zip |
Resurrect the llama npy support. (#140)
Diffstat (limited to 'candle-examples/examples/llama/convert_checkpoint.py')
-rw-r--r-- | candle-examples/examples/llama/convert_checkpoint.py | 251 |
1 file changed, 191 insertions(+), 60 deletions(-)
diff --git a/candle-examples/examples/llama/convert_checkpoint.py b/candle-examples/examples/llama/convert_checkpoint.py
index 245c167c..1b44a04a 100644
--- a/candle-examples/examples/llama/convert_checkpoint.py
+++ b/candle-examples/examples/llama/convert_checkpoint.py
@@ -1,68 +1,199 @@
-# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
-import sys
+# Adapted from:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+import argparse
+import gc
+import json
+import math
+import os
+import shutil
+import warnings
+
 import torch
 import numpy as np
-from typing import Dict
-from pathlib import Path
-
-def tr(v):
-    return np.ascontiguousarray(np.transpose(v))
-
-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
-    print("start conv")
-
-    def get_and_remove(key, transpose=False):
-        v = state_dict[key].to(dtype).numpy()
-        if transpose:
-            v = tr(v)
-        del state_dict[key]
-        return v
-
-    converted = {}
-    converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
-    converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
-    converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
-
-    for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
-        print(layer_idx)
-
-        # attention
-        # the wq, wk, wv from the FB model are stacked in our model as c_attn
-        converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
-            (
-                get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
-                get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
-                get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
+
+"""
+Sample usage:
+
+```
+python src/transformers/models/llama/convert_llama_weights_to_hf.py \
+    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
+```
+"""
+
+INTERMEDIATE_SIZE_MAP = {
+    "7B": 11008,
+    "13B": 13824,
+    "30B": 17920,
+    "65B": 22016,
+}
+NUM_SHARDS = {
+    "7B": 1,
+    "13B": 2,
+    "30B": 4,
+    "65B": 8,
+}
+
+
+def compute_intermediate_size(n):
+    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
+
+
+def read_json(path):
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+def write_json(text, path):
+    with open(path, "w") as f:
+        json.dump(text, f)
+
+
+def write_model(model_path, input_base_path, model_size):
+    os.makedirs(model_path, exist_ok=True)
+
+    params = read_json(os.path.join(input_base_path, "params.json"))
+    num_shards = NUM_SHARDS[model_size]
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    n_heads_per_shard = n_heads // num_shards
+    dim = params["dim"]
+    dims_per_head = dim // n_heads
+    base = 10000.0
+    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+
+    # permute for sliced rotary
+    def permute(w):
+        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+
+    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+    # Load weights
+    if model_size == "7B":
+        # Not sharded
+        # (The sharded implementation would also work, but this is simpler.)
+        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
+    else:
+        # Sharded
+        loaded = [
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            for i in range(num_shards)
+        ]
+    param_count = 0
+    all_dicts = {}
+    for layer_i in range(n_layers):
+        if model_size == "7B":
+            # Unsharded
+            state_dict = {
+                f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wq.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wk.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
+                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
+                f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
+                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
+                f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
+            }
+        else:
+            # Sharded
+            # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
+            # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
+            # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
+
+            state_dict = {
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
+                    f"layers.{layer_i}.attention_norm.weight"
+                ].clone(),
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
+                    f"layers.{layer_i}.ffn_norm.weight"
+                ].clone(),
+            }
+            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
+                torch.cat(
+                    [
+                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                        for i in range(num_shards)
+                    ],
+                    dim=0,
+                ).reshape(dim, dim)
+            )
+            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
+                torch.cat(
+                    [
+                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                        for i in range(num_shards)
+                    ],
+                    dim=0,
+                ).reshape(dim, dim)
             )
-        ))
-        converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
-            f"layers.{layer_idx}.attention.wo.weight"
-        ))
-        # mlp
-        converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
+            state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
+                [
+                    loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                    for i in range(num_shards)
+                ],
+                dim=0,
+            ).reshape(dim, dim)
+
+            state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
            )
-        converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
+            state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
            )
-        converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
+            state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
            )
-        # rms norm
-        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
-        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
-    return converted
-
-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
-    dt = getattr(torch, dtype, None)
-    if not isinstance(dt, torch.dtype):
-        raise ValueError(f"{dtype} is not a valid dtype.")
-    checkpoint = torch.load(llama_ckpt, map_location="cpu")
-    converted = convert_state_dict(checkpoint, dtype=dt)
-    del checkpoint
-    np.savez(output_npz, **converted)
+            state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
+            )
+
+        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
+        all_dicts |= state_dict
+
+    if model_size == "7B":
+        # Unsharded
+        state_dict = {
+            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
+            "model.norm.weight": loaded["norm.weight"],
+            "lm_head.weight": loaded["output.weight"],
+        }
+    else:
+        state_dict = {
+            "model.norm.weight": loaded[0]["norm.weight"],
+            "model.embed_tokens.weight": torch.cat(
+                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
+            ),
+            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
+        }
+    all_dicts |= state_dict
+    all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
+    np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input_dir",
+        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
+    )
+    parser.add_argument(
+        "--model_size",
+        choices=["7B", "13B", "30B", "65B"],
+    )
+    parser.add_argument(
+        "--output_dir",
+        help="Location to write HF model and tokenizer",
+    )
+    args = parser.parse_args()
+    write_model(
+        model_path=args.output_dir,
+        input_base_path=os.path.join(args.input_dir, args.model_size),
+        model_size=args.model_size,
+    )
+

 if __name__ == "__main__":
-    if len(sys.argv) != 2:
-        raise ValueError(f"usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
-    convert_weights(sys.argv[1])
+    main()