| | |
| |
|
| | import numpy as np |
| | import torch |
| | import onnxruntime as ort |
| | from transformers import AutoTokenizer |
| | from huggingface_hub import hf_hub_download |
| |
|
| |
|
class QuantizedSentenceEncoder:
    """Sentence encoder backed by an INT8-quantized ONNX model from the HuggingFace Hub."""

    def __init__(self, model_name: str = "skatzR/USER-BGE-M3-ONNX-INT8", device: str = None):
        """
        Universal loader for quantized ONNX model.

        :param model_name: HuggingFace repo id
        :param device: "cpu" or "cuda". If None, selected automatically.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Auto-select the device from the onnxruntime build when not given explicitly.
        if device is None:
            self.device = "cuda" if ort.get_device() == "GPU" else "cpu"
        else:
            self.device = device

        # Download (or reuse the locally cached copy of) the quantized weights.
        model_path = hf_hub_download(repo_id=model_name, filename="model_quantized.onnx")

        # List CUDA first so ORT can fall back to CPU if CUDA is unavailable at runtime.
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if self.device == "cuda"
            else ["CPUExecutionProvider"]
        )
        self.session = ort.InferenceSession(model_path, providers=providers)

        # Cache graph input/output names so encode() can filter tokenizer outputs.
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_name = self.session.get_outputs()[0].name

    def _mean_pooling(self, model_output, attention_mask):
        """
        Mask-aware mean pooling over the token axis (as in the original model).

        :param model_output: np.ndarray (batch_size, seq_len, hidden_size)
        :param attention_mask: torch.Tensor or array-like (batch_size, seq_len)
        :return: np.ndarray (batch_size, hidden_size)
        """
        token_embeddings = model_output
        # np.asarray accepts both torch CPU tensors and plain arrays/lists,
        # so callers are no longer forced to pass a torch.Tensor.
        input_mask_expanded = np.expand_dims(np.asarray(attention_mask), -1)

        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
        # Clip so a fully-masked row does not divide by zero.
        sum_mask = np.clip(input_mask_expanded.sum(axis=1), a_min=1e-9, a_max=None)
        return sum_embeddings / sum_mask

    def encode(self, texts, normalize: bool = True, batch_size: int = 32):
        """
        Get sentence embeddings.

        :param texts: str or list[str]
        :param normalize: whether to apply L2 normalization
        :param batch_size: batch size
        :return: np.ndarray (num_texts, hidden_size)
        """
        if isinstance(texts, str):
            texts = [texts]

        # np.vstack([]) would raise a cryptic ValueError on an empty input;
        # return an empty 2-D array instead (hidden size is unknown without
        # running the model, so the second dimension is 0).
        if not texts:
            return np.empty((0, 0), dtype=np.float32)

        all_embeddings = []

        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            enc = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )

            # Feed only the inputs the ONNX graph actually declares.
            ort_inputs = {k: v.cpu().numpy() for k, v in enc.items() if k in self.input_names}

            ort_outs = self.session.run([self.output_name], ort_inputs)
            token_embeddings = ort_outs[0]

            embeddings = self._mean_pooling(token_embeddings, enc["attention_mask"])

            if normalize:
                # Clip norms so an all-zero embedding yields zeros, not NaNs.
                norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
                embeddings = embeddings / np.clip(norms, 1e-12, None)

            all_embeddings.append(embeddings)

        return np.vstack(all_embeddings)
| |
|