author     laurent <laurent.mazare@gmail.com>    2023-07-03 11:30:58 +0100
committer  laurent <laurent.mazare@gmail.com>    2023-07-03 11:30:58 +0100
commit     fdb1acd2ffa961a0e8e83ae5de30d19213419a6c (patch)
tree       2b7a5028b80c2da08c77628f5d2b211391de60ef  /candle-examples/examples/llama/convert_checkpoint.py
parent     d0d530dfdce04d5fb656b10b4eb1bfd26dea37e8 (diff)
Move llama into the candle-examples directory.
Diffstat (limited to 'candle-examples/examples/llama/convert_checkpoint.py')
-rw-r--r--  candle-examples/examples/llama/convert_checkpoint.py  68
1 file changed, 68 insertions(+), 0 deletions(-)
diff --git a/candle-examples/examples/llama/convert_checkpoint.py b/candle-examples/examples/llama/convert_checkpoint.py
new file mode 100644
index 00000000..245c167c
--- /dev/null
+++ b/candle-examples/examples/llama/convert_checkpoint.py
@@ -0,0 +1,68 @@
+# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
+import sys
+import torch
+import numpy as np
+from typing import Dict
+from pathlib import Path
+
+def tr(v):
+    return np.ascontiguousarray(np.transpose(v))
+
+def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, np.ndarray]:
+    print("start conv")
+
+    def get_and_remove(key, transpose=False):
+        v = state_dict[key].to(dtype).numpy()
+        if transpose:
+            v = tr(v)
+        del state_dict[key]
+        return v
+
+    converted = {}
+    converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
+    converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
+    converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
+
+    for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
+        print(layer_idx)
+
+        # attention
+        # the wq, wk, wv from the FB model are stacked in our model as c_attn
+        converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
+            (
+                get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
+                get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
+                get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
+            )
+        ))
+        converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
+            f"layers.{layer_idx}.attention.wo.weight"
+        ))
+        # mlp
+        converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
+            f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
+        )
+        converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
+            f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
+        )
+        converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
+            f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
+        )
+        # rms norm
+        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
+        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
+    return converted
+
+def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
+    dt = getattr(torch, dtype, None)
+    if not isinstance(dt, torch.dtype):
+        raise ValueError(f"{dtype} is not a valid dtype.")
+    checkpoint = torch.load(llama_ckpt, map_location="cpu")
+    converted = convert_state_dict(checkpoint, dtype=dt)
+    del checkpoint
+    np.savez(output_npz, **converted)
+
+if __name__ == "__main__":
+    if len(sys.argv) != 2:
+        raise ValueError("usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
+    convert_weights(sys.argv[1])
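
The script writes every converted tensor into a single NumPy archive. Below is a minimal sanity-check sketch, not part of the commit: it assumes the default llama.npz output produced by convert_weights above and a 7B checkpoint (where wq, wk and wv all share the embedding width), and simply inspects the key layout produced by convert_state_dict.

# Sanity-check sketch (illustrative, not part of the commit): inspect the
# archive written by convert_weights, assuming the default "llama.npz" path.
import numpy as np

weights = np.load("llama.npz")

# List every converted tensor with its shape and dtype.
for name in sorted(weights.files):
    print(f"{name:60} {weights[name].shape} {weights[name].dtype}")

# For a 7B checkpoint, wq, wk and wv share the embedding width, so the stacked
# and transposed c_attn weight should be (n_embd, 3 * n_embd).
n_embd = weights["transformer.wte.weight"].shape[1]
assert weights["transformer.h.0.attn.c_attn.weight"].shape == (n_embd, 3 * n_embd)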
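
convert_weights can also be imported and called directly rather than through the command-line guard, for instance to write a half-precision archive to a custom path. The checkpoint path below is a placeholder, and the snippet assumes it runs next to convert_checkpoint.py; output_npz and dtype are the keyword-only parameters defined above.

# Usage sketch (illustrative): call convert_weights directly with a custom
# output path and dtype. The checkpoint path is a placeholder.
from pathlib import Path

from convert_checkpoint import convert_weights

convert_weights(
    "/path/to/LLaMA/7B/consolidated.00.pth",  # placeholder checkpoint path
    output_npz=Path("llama-f16.npz"),
    dtype="float16",
)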