summaryrefslogtreecommitdiff
path: root/candle-pyo3/e5.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/e5.py')
-rw-r--r--candle-pyo3/e5.py104
1 files changed, 104 insertions, 0 deletions
diff --git a/candle-pyo3/e5.py b/candle-pyo3/e5.py
new file mode 100644
index 00000000..a0af0c56
--- /dev/null
+++ b/candle-pyo3/e5.py
@@ -0,0 +1,104 @@
+from candle.utils import load_safetensors, save_gguf, load_gguf
+from candle.models.bert import BertModel, Config
+import json
+from candle import Tensor
+from tqdm import tqdm
+from dataclasses import fields
+import os
+import time
+
+from huggingface_hub import hf_hub_download
+from transformers import BertTokenizer, AutoModel
+import torch
+
+if __name__ == "__main__":
+ model_name = "intfloat/e5-small-v2"
+ model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
+ config_file = hf_hub_download(repo_id=model_name, filename="config.json")
+
+ tensors = load_safetensors(model_file)
+ config = Config()
+ with open(config_file, "r") as f:
+ raw_config = json.load(f)
+ for field in fields(config):
+ if field.name in raw_config:
+ setattr(config, field.name, raw_config[field.name])
+
+ # Load the model
+ model = BertModel(config)
+ model.load_state_dict(tensors)
+
+ hf_model = AutoModel.from_pretrained(model_name)
+ tokenizer = BertTokenizer.from_pretrained(model_name)
+
+ sentences = [
+ "The cat sits outside",
+ "A man is playing guitar",
+ "I love pasta",
+ "The new movie is awesome",
+ "The cat plays in the garden",
+ "A woman watches TV",
+ "The new movie is so great",
+ "Do you like pizza?",
+ ]
+
+ def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
+ """Average the hidden states according to the attention mask"""
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+ return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+ tokenized = tokenizer(sentences, padding=True)
+ tokens = Tensor(tokenized["input_ids"])
+ token_type_ids = Tensor(tokenized["token_type_ids"])
+ encoder_out, _ = model.forward(tokens, token_type_ids)
+
+ hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
+ hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
+
+ hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"])
+ candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"])
+
+ loss = torch.nn.L1Loss()
+ error = loss(hf_pooled, candle_pooled).mean().item()
+ print(f"Mean error between torch-referenze and candle: {error}")
+
+ # Quantize all attention 'weights'
+ quantized_tensors = {}
+ for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors to 5-Bit"):
+ if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name):
+ # check if the tensor is k-quantizable
+ if tensor.shape[-1] % 256 == 0:
+ new_tensor = tensor.quantize("q4k")
+ else:
+ new_tensor = tensor.quantize("q5_0")
+ quantized_tensors[name] = new_tensor
+ else:
+ quantized_tensors[name] = tensor.quantize("q8_0")
+
+ print(f"Saving quantized tensors")
+ # Remove all None values from the config
+ config_to_save = {k: v for k, v in config.__dict__.items() if v is not None}
+ # Save the model
+ quantized_model_file = "e5_small.gguf"
+ save_gguf(quantized_model_file, quantized_tensors, config_to_save)
+
+ file_size_mb = os.path.getsize(model_file) / 1024 / 1024
+ file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024
+ print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")
+ # Load the model from the gguf
+ tensors, raw_config = load_gguf(quantized_model_file)
+ config = Config()
+ for field in fields(config):
+ if field.name in raw_config:
+ setattr(config, field.name, raw_config[field.name])
+ model = BertModel(config)
+ # "embeddings.position_ids" is missing in the gguf as it is i64
+ model.load_state_dict(tensors, strict=False)
+
+ # Run the model again
+ encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids)
+ encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu")
+
+ candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"])
+ error = loss(hf_pooled, candle_pooled_2).mean().item()
+ print(f"Mean error between torch-referenze and quantized-candle: {error}")