diff options
Diffstat (limited to 'candle-pyo3/e5.py')
-rw-r--r-- | candle-pyo3/e5.py | 104 |
1 files changed, 104 insertions, 0 deletions
"""Compare candle's BERT implementation with the Hugging Face reference, then quantize it.

The script (candle-pyo3/e5.py):

1. downloads ``intfloat/e5-small-v2`` (weights and config) from the Hugging Face hub,
2. runs the same sentences through candle's ``BertModel`` and the ``transformers``
   reference model and prints the mean absolute difference of the pooled embeddings,
3. quantizes the weights, stores them in a GGUF file and prints the size reduction,
4. reloads the GGUF file and prints the error of the quantized model.
"""
import json
import os
import time
from dataclasses import fields

import torch
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from transformers import AutoModel, BertTokenizer

from candle import Tensor
from candle.models.bert import BertModel, Config
from candle.utils import load_gguf, load_safetensors, save_gguf

MODEL_NAME = "intfloat/e5-small-v2"
QUANTIZED_MODEL_FILE = "e5_small.gguf"

SENTENCES = [
    "The cat sits outside",
    "A man is playing guitar",
    "I love pasta",
    "The new movie is awesome",
    "The cat plays in the garden",
    "A woman watches TV",
    "The new movie is so great",
    "Do you like pizza?",
]


def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average the hidden states according to the attention mask.

    Positions where ``attention_mask`` is zero are zeroed out before summing over
    dim 1, and the sum is divided by the number of attended positions.
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def _config_from_dict(raw_config) -> Config:
    """Return a default ``Config`` with every field present in ``raw_config`` overridden."""
    config = Config()
    for field in fields(config):
        if field.name in raw_config:
            setattr(config, field.name, raw_config[field.name])
    return config


def _quantize_tensors(tensors):
    """Quantize every tensor of ``tensors`` and return them in a new dict.

    Weights whose name contains "attention", "intermediate" or "output" are
    quantized to "q4k" when their last dimension is a multiple of 256 and to
    "q5_0" otherwise; all remaining tensors are quantized to "q8_0".
    """
    quantized_tensors = {}
    for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors"):
        if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name):
            # The original script treats a last dimension divisible by 256 as "k-quantizable".
            dtype = "q4k" if tensor.shape[-1] % 256 == 0 else "q5_0"
        else:
            dtype = "q8_0"
        quantized_tensors[name] = tensor.quantize(dtype)
    return quantized_tensors


def main() -> None:
    """Run the candle-vs-torch comparison, quantize the model and compare again."""
    model_file = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
    config_file = hf_hub_download(repo_id=MODEL_NAME, filename="config.json")

    tensors = load_safetensors(model_file)
    with open(config_file, "r", encoding="utf-8") as f:
        config = _config_from_dict(json.load(f))

    # Load the model
    model = BertModel(config)
    model.load_state_dict(tensors)

    hf_model = AutoModel.from_pretrained(MODEL_NAME)
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

    # candle forward pass
    tokenized = tokenizer(SENTENCES, padding=True)
    tokens = Tensor(tokenized["input_ids"])
    token_type_ids = Tensor(tokenized["token_type_ids"])
    encoder_out, _ = model.forward(tokens, token_type_ids)

    # torch reference forward pass; inference only, so no autograd graph is needed
    hf_tokenized = tokenizer(SENTENCES, padding=True, return_tensors="pt")
    with torch.no_grad():
        hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
    attention_mask = hf_tokenized["attention_mask"]

    hf_pooled = average_pool(hf_result, attention_mask)
    candle_pooled = average_pool(torch.tensor(encoder_out.values()), attention_mask)

    # L1Loss already reduces to the mean by default
    loss = torch.nn.L1Loss()
    error = loss(hf_pooled, candle_pooled).item()
    print(f"Mean error between torch-reference and candle: {error}")

    # Quantize all tensors
    quantized_tensors = _quantize_tensors(tensors)

    print("Saving quantized tensors")
    # Remove all None values from the config
    config_to_save = {k: v for k, v in vars(config).items() if v is not None}
    # Save the model
    save_gguf(QUANTIZED_MODEL_FILE, quantized_tensors, config_to_save)

    file_size_mb = os.path.getsize(model_file) / 1024 / 1024
    file_size_mb_compressed = os.path.getsize(QUANTIZED_MODEL_FILE) / 1024 / 1024
    print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")

    # Load the model from the gguf
    tensors, raw_config = load_gguf(QUANTIZED_MODEL_FILE)
    model = BertModel(_config_from_dict(raw_config))
    # "embeddings.position_ids" is missing in the gguf as it is i64
    model.load_state_dict(tensors, strict=False)

    # Run the model again
    encoder_out_2, _ = model.forward(tokens, token_type_ids)
    encoder_out_2 = encoder_out_2.to_device("cpu")

    candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), attention_mask)
    error = loss(hf_pooled, candle_pooled_2).item()
    print(f"Mean error between torch-reference and quantized-candle: {error}")


if __name__ == "__main__":
    main()