raxtemur committed on
Commit a8acb5a (verified)
1 Parent(s): 5831667

Initial upload (weights + code + README)

README.md ADDED
@@ -0,0 +1,33 @@
1
+ # SONAR-LLM (300M)
2
+
3
+ We present SONAR-LLM, a decoder-only transformer that "thinks" in the continuous SONAR embedding space, yet is supervised through token-level cross-entropy propagated via the frozen SONAR decoder. This hybrid objective retains the semantic abstraction of the Large Concept Model (LCM) while eliminating its diffusion sampler and restoring a likelihood-based training signal. Across model sizes from 39M to 1.3B parameters, SONAR-LLM attains competitive generation quality.
4
+
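+ A minimal sketch of the objective described above (illustrative only; `llama`, `forward_proj`, `reverse_proj` and the decoder call are assumed names, not the exact training code):
+
+ ```python
+ import torch.nn.functional as F
+
+ def sonar_llm_loss(llama, forward_proj, reverse_proj, frozen_sonar_decoder,
+                    sent_embs, next_sent_tokens):
+     # sent_embs: (B, T, 1024) SONAR embeddings of the sentence prefix
+     hidden = llama(inputs_embeds=forward_proj(sent_embs),
+                    output_hidden_states=True).hidden_states[-1]
+     pred_emb = reverse_proj(hidden[:, -1, :])  # predicted next-sentence embedding
+     # Hypothetical decoder call: the frozen SONAR decoder turns the embedding into
+     # teacher-forced token logits; gradients flow through it, but its weights stay frozen.
+     logits = frozen_sonar_decoder(pred_emb, next_sent_tokens)
+     return F.cross_entropy(logits.flatten(0, 1), next_sent_tokens.flatten())
+ ```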
5
+ Original repository: `https://github.com/FusionBrainLab/SONAR-LLM`
6
+ Paper: `https://arxiv.org/abs/2508.05305`
7
+
8
+ A minimal bundle containing the SONAR-LLM 300M checkpoint and the inference code.
9
+
10
+ ## Install
11
+ - Use a fresh venv or conda environment
12
+ - Install SONAR from the official repo: `https://github.com/facebookresearch/SONAR`
13
+ - Ensure PyTorch and transformers are installed
14
+ - (Optional) Download the NLTK punkt tokenizer: `python -c "import nltk; nltk.download('punkt')"` (a quick sanity check is sketched below)
15
+
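+ After installing, a quick import check (using only packages referenced in this bundle) can confirm the environment:
+
+ ```python
+ # Sanity check for the dependencies listed above
+ import nltk
+ import torch
+ import transformers
+ from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
+
+ nltk.download("punkt", quiet=True)
+ print(torch.__version__, transformers.__version__)
+ ```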
16
+ ## Usage
17
+ ```python
18
+ from sonarllm_model import SONARLLMGenerator, SONARLLMGenerationConfig
19
+
20
+ gen = SONARLLMGenerator.load_from_checkpoint(".")
21
+ eos_emb = gen.t2vec.predict(["End of sequence."], source_lang="eng_Latn").to(gen.device)
22
+ cfg = SONARLLMGenerationConfig(temperature=0.2, latent_top_p=0.9, decoder_beam_size=1)
23
+ print(gen.generate("Once upon a time", eos_emb, cfg))
24
+ ```
25
+
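+ The bundled decoder pipeline (`sonarllm_model/embedding_to_text_with_scores.py`) can also return per-sentence hypothesis scores. A small example reusing the `gen` object from above:
+
+ ```python
+ # Round-trip a sentence through SONAR and read back the decoder score
+ emb = gen.t2vec.predict(["A short test sentence."], source_lang="eng_Latn").to(gen.device)
+ texts, scores = gen.sonar_decoder.predict(emb, target_lang="eng_Latn", return_scores=True)
+ print(texts[0], scores[0])
+ ```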
26
+ ## Files
27
+ - `pytorch_model.bin`
28
+ - `config.json`
29
+ - `sonarllm_model/`
30
+
31
+ ## Notes
32
+ - SONAR install guide: `https://github.com/facebookresearch/SONAR`
33
+ - Tokenizer name is taken from `config.json`.
config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "pretrained_model_name_or_path": "meta-llama/Llama-3.2-1B",
3
+ "llama_config": {
4
+ "hidden_size": 1024,
5
+ "intermediate_size": 4096,
6
+ "num_hidden_layers": 10,
7
+ "num_attention_heads": 16,
8
+ "hidden_act": "silu",
9
+ "max_position_embeddings": 131072,
10
+ "initializer_range": 0.02,
11
+ "rms_norm_eps": 1e-06,
12
+ "use_cache": true,
13
+ "pretraining_tp": 1,
14
+ "tie_word_embeddings": true,
15
+ "rope_theta": 500000.0,
16
+ "rope_scaling": {
17
+ "factor": 32.0,
18
+ "high_freq_factor": 4.0,
19
+ "low_freq_factor": 1.0,
20
+ "original_max_position_embeddings": 8192,
21
+ "rope_type": "llama3"
22
+ },
23
+ "attention_bias": false,
24
+ "attention_dropout": 0.0,
25
+ "mlp_bias": false,
26
+ "head_dim": 64
27
+ },
28
+ "embed_dim": 1024
29
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99730076a888df2aa898d4655963268e11048e06021854d9a2944e46ae7ee21f
3
+ size 1204940238
sonarllm_model/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .sonar_llm_model import SONARLLMGenerator, SONARLLMGenerationConfig
2
+
3
+
4
+
5
+
sonarllm_model/embedding_to_text_with_scores.py ADDED
@@ -0,0 +1,94 @@
1
+ import warnings
2
+ from typing import Iterable, List, Optional
3
+
4
+ import torch
5
+
6
+ from fairseq2.generation import (
7
+ BeamSearchSeq2SeqGenerator,
8
+ Sampler,
9
+ SamplingSeq2SeqGenerator,
10
+ Seq2SeqGenerator,
11
+ SequenceToTextConverter,
12
+ )
13
+
14
+ from sonar.inference_pipelines.utils import add_progress_bar
15
+ from sonar.inference_pipelines.text import (
16
+ EmbeddingToTextModelPipeline as _BaseEmbeddingToTextModelPipeline,
17
+ )
18
+ from fairseq2.data.data_pipeline import read_sequence
19
+
20
+
21
+ class EmbeddingToTextModelPipeline(_BaseEmbeddingToTextModelPipeline):
22
+ """Drop-in replacement that can also return sentence log-probabilities via return_scores.
23
+
24
+ - When return_scores=False (default), behaves exactly like the base pipeline and returns List[str].
25
+ - When return_scores=True, returns a tuple (List[str], List[float]) where each float is the
26
+ hypothesis score from fairseq2 (sum of token log-probabilities if normalize_scores=False,
27
+ otherwise length-normalized per fairseq2 semantics).
28
+ """
29
+
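+ # Usage sketch (illustrative, names assumed):
+ #   pipe = EmbeddingToTextModelPipeline(decoder="text_sonar_basic_decoder",
+ #                                       tokenizer="text_sonar_basic_encoder", device=device)
+ #   texts = pipe.predict(embeddings, target_lang="eng_Latn")                        # List[str]
+ #   texts, scores = pipe.predict(embeddings, target_lang="eng_Latn", return_scores=True)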
30
+ @torch.inference_mode()
31
+ def predict(
32
+ self,
33
+ inputs: torch.Tensor,
34
+ target_lang: str,
35
+ batch_size: int = 5,
36
+ progress_bar: bool = False,
37
+ sampler: Optional[Sampler] = None,
38
+ return_scores: bool = False,
39
+ **generator_kwargs,
40
+ ):
41
+ if sampler is not None:
42
+ generator: Seq2SeqGenerator = SamplingSeq2SeqGenerator(
43
+ self.model, sampler, **generator_kwargs
44
+ )
45
+ else:
46
+ generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs)
47
+
48
+ converter = SequenceToTextConverter(
49
+ generator,
50
+ self.tokenizer,
51
+ task="translation",
52
+ target_lang=target_lang,
53
+ )
54
+
55
+ def _do_translate(src_tensors: List[torch.Tensor]):
56
+ texts, gen_out = converter.batch_convert(
57
+ torch.stack(src_tensors).to(self.device), None
58
+ )
59
+ if return_scores:
60
+ scores: List[float] = []
61
+ for hyps in gen_out.hypotheses:
62
+ if len(hyps) == 0 or hyps[0].score is None:
63
+ scores.append(0.0)
64
+ else:
65
+ scores.append(float(hyps[0].score))
66
+ return texts, scores
67
+ return texts
68
+
69
+ pipeline: Iterable = (
70
+ read_sequence(list(inputs))
71
+ .bucket(batch_size)
72
+ .map(_do_translate)
73
+ .and_return()
74
+ )
75
+
76
+ if progress_bar:
77
+ pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size)
78
+
79
+ results: List = list(iter(pipeline))
80
+
81
+ if not return_scores:
82
+ # results is List[List[str]] → flatten
83
+ return [text for batch_texts in results for text in batch_texts]
84
+
85
+ # results is List[Tuple[List[str], List[float]]] → flatten both
86
+ all_texts: List[str] = []
87
+ all_scores: List[float] = []
88
+ for batch in results:
89
+ batch_texts, batch_scores = batch
90
+ all_texts.extend(batch_texts)
91
+ all_scores.extend(batch_scores)
92
+ return all_texts, all_scores
93
+
94
+
sonarllm_model/sonar_llm_model.py ADDED
@@ -0,0 +1,366 @@
1
+ import time
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple
4
+
5
+
6
+ import nltk
7
+ from nltk.tokenize import sent_tokenize
8
+ nltk.download("punkt", quiet=True)
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
15
+ from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
16
+
17
+ class Projector(nn.Module):
18
+ def __init__(self, in_dim: int, out_dim: int):
19
+ super().__init__()
20
+ self.linear = nn.Linear(in_dim, out_dim)
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return self.linear(x)
24
+
25
+ @dataclass
26
+ class SONARLLMGenerationConfig:
27
+ # Outer sentence-level beam
28
+ sentence_beam_size: int = 4
29
+ latent_samples_per_step: int = 4 # M latent variants per active beam state
30
+
31
+ # Token-level decoder params
32
+ decoder_beam_size: int = 5 # default in fairseq2
33
+ decoder_temperature: float = 1.0 # default in fairseq2
34
+ normalize_sentence_scores: bool = True # False → sum of token log-probs
35
+ decoder_max_len: int = 256
36
+
37
+ # Latent sampling
38
+ temperature: float = 0.4
39
+ latent_top_p: Optional[float] = None # 0<p<=1 or None for Gaussian
40
+ temperature_mode: str = "relative" # "absolute" | "relative"
41
+
42
+ # Repetition control in latent space
43
+ repetition_penalty: float = 0.0
44
+ repetition_memory: int = 0
45
+
46
+ # Termination
47
+ max_sentences: int = 32
48
+ eos_threshold: float = 0.98
49
+
50
+
51
+ class SONARLLMGenerator(torch.nn.Module):
52
+ """Sentence-level beam over latent reversed embeddings using SONAR decoder.
53
+
54
+ For each step:
55
+ - Run LLaMA on the sentence embedding history to get final hidden.
56
+ - Sample multiple latent directions (temperature/latent_top_p, with repetition penalty).
57
+ - Project back to the SONAR embedding space via `reverse_proj` and decode text with the SONAR decoder.
58
+ - Score each candidate using decoder sentence logprob (+ optional shaping).
59
+ - Keep top `sentence_beam_size` states and continue until EOS or max sentences.
60
+
61
+ This class does NOT modify existing project files and can be used standalone.
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ llama_model: nn.Module,
67
+ forward_proj: nn.Module,
68
+ reverse_proj: nn.Module,
69
+ sonar_decoder: EmbeddingToTextModelPipeline,
70
+ t2vec_model: TextToEmbeddingModelPipeline,
71
+ device: torch.device,
72
+ ) -> None:
73
+ super().__init__()
74
+ self.llama_model = llama_model.eval()
75
+ self.forward_proj = forward_proj.eval()
76
+ self.reverse_proj = reverse_proj.eval()
77
+ self.sonar_decoder = sonar_decoder.eval()
78
+ self.t2vec = t2vec_model.eval()
79
+ self.device = device
80
+
81
+ @torch.no_grad()
82
+ def generate(self, prefix_text: str, eos_emb: torch.Tensor, cfg: Optional[SONARLLMGenerationConfig] = None) -> str:
83
+ # Normalize and attach config to the instance for helper use
84
+ if cfg is None:
85
+ cfg = SONARLLMGenerationConfig()
86
+ self._cfg = cfg
87
+ sents = sent_tokenize(prefix_text)
88
+ if len(sents) == 0:
89
+ sents = [prefix_text.strip()]
90
+
91
+ # Initialize prefix embeddings
92
+ emb_seq = self.t2vec.predict(sents, source_lang="eng_Latn").to(self.device)
93
+
94
+ # Beam state tuple: (sentences, embeddings_seq, cumulative_score, recent_dirs)
95
+ beams: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = [
96
+ (sents[:], emb_seq, 0.0, [])
97
+ ]
98
+
99
+ steps = 0
100
+ while steps < self._cfg.max_sentences:
101
+ steps += 1
102
+ candidates: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
103
+
104
+ for (hist_sents, hist_emb, score, recent_dirs) in beams:
105
+ candidates.extend(
106
+ self._expand_beam_state(hist_sents, hist_emb, score, recent_dirs, eos_emb)
107
+ )
108
+
109
+ # Keep top-k beams
110
+ if len(candidates) == 0:
111
+ break
112
+ candidates.sort(key=lambda b: b[2], reverse=True)
113
+ beams = candidates[: int(self._cfg.sentence_beam_size)]
114
+
115
+ # If all beams look ended by EOS threshold, stop early
116
+ if self._all_close_to_eos(beams, eos_emb):
117
+ break
118
+
119
+ best = max(beams, key=lambda b: b[2])
120
+ return self._join_sentences(best[0])
121
+
122
+ # --- internals ---
123
+
124
+ @torch.no_grad()
125
+ def _forward_hidden(self, emb_seq: torch.Tensor) -> torch.Tensor:
126
+ proj = self.forward_proj(emb_seq.unsqueeze(0)) if emb_seq.ndim == 2 else self.forward_proj(emb_seq)
127
+ out = self.llama_model(inputs_embeds=proj, output_hidden_states=True)
128
+ hidden = out.hidden_states[-1]
129
+ return hidden[0, -1, :]
130
+
131
+ def _join_sentences(self, sents: List[str]) -> str:
132
+ return " ".join(sents)
133
+
134
+ def _update_recent_dirs(
135
+ self, recent: List[torch.Tensor], u: torch.Tensor, memory_cap: int
136
+ ) -> List[torch.Tensor]:
137
+ if memory_cap <= 0:
138
+ return recent
139
+ if not torch.isfinite(u).all():
140
+ return recent
141
+ new_recent = recent + [u.detach().to("cpu")]
142
+ if len(new_recent) > int(memory_cap):
143
+ new_recent = new_recent[-int(memory_cap) :]
144
+ return new_recent
145
+
146
+ def _sample_noise_direction(
147
+ self, final_hidden: torch.Tensor, recent_dirs: List[torch.Tensor]
148
+ ) -> torch.Tensor:
149
+ g = torch.randn_like(final_hidden)
150
+ if (
151
+ self._cfg.repetition_penalty is not None
152
+ and float(self._cfg.repetition_penalty) > 0.0  # 0.0 disables the latent repetition penalty
153
+ and self._cfg.repetition_memory > 0
154
+ and len(recent_dirs) > 0
155
+ ):
156
+ g = self._apply_repetition_penalty_to_direction(
157
+ g, float(self._cfg.repetition_penalty), int(self._cfg.repetition_memory), recent_dirs
158
+ )
159
+ return g / (g.norm(p=2) + 1e-12)
160
+
161
+ def _sample_noise(
162
+ self, final_hidden: torch.Tensor, dir_unit: torch.Tensor
163
+ ) -> torch.Tensor:
164
+ t = float(self._cfg.temperature)
165
+ if t <= 0.0:
166
+ return torch.zeros_like(final_hidden)
167
+
168
+ if self._cfg.temperature_mode not in ("absolute", "relative"):
169
+ raise ValueError(f"Unsupported temperature_mode: {self._cfg.temperature_mode}")
170
+
171
+ if self._cfg.temperature_mode == "absolute":
172
+ sigma = torch.tensor(t, device=final_hidden.device, dtype=final_hidden.dtype)
173
+ else:
174
+ rms = torch.sqrt(torch.mean(final_hidden.to(torch.float32) ** 2))
175
+ rms = torch.clamp(rms, min=1e-12).to(dtype=final_hidden.dtype, device=final_hidden.device)
176
+ sigma = rms * t
177
+
178
+ top_p = self._cfg.latent_top_p
179
+ if top_p is None:
180
+ top_p = 1.0
181
+ return self._sample_truncated_normal_like(final_hidden, float(top_p), sigma, dir_unit)
182
+
183
+ def _sample_truncated_normal_like(
184
+ self, base_vector: torch.Tensor, top_p: float, sigma: torch.Tensor, dir_unit: torch.Tensor
185
+ ) -> torch.Tensor:
186
+ # Wilson–Hilferty approximation for ChiSquare quantiles
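+ # The noise radius r is a top_p-truncated chi-distribution quantile with `dim` degrees of
+ # freedom: draw p ~ U(0, top_p), take the standard-normal quantile z_p, approximate the
+ # chi-square quantile as k*(1 - 2/(9k) + z_p*sqrt(2/(9k)))**3 (Wilson-Hilferty), and set
+ # r to its square root, so ||noise|| = r * sigma mimics the norm of an isotropic Gaussian draw.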
187
+ dim = base_vector.numel()
188
+ device = base_vector.device
189
+ u = torch.rand((), device=device, dtype=torch.float32)
190
+ p = torch.clamp(u * float(top_p), min=1e-12, max=1.0 - 1e-12)
191
+ k = torch.tensor(float(dim), device=device, dtype=torch.float32)
192
+ z = torch.sqrt(torch.tensor(2.0, device=device, dtype=torch.float32)) * torch.special.erfinv(2.0 * p - 1.0)
193
+ term = 1.0 - 2.0 / (9.0 * k) + z * torch.sqrt(2.0 / (9.0 * k))
194
+ term = torch.clamp(term, min=1e-12)
195
+ s = k * (term ** 3)
196
+ r = torch.sqrt(torch.clamp(s, min=1e-12)).to(dtype=base_vector.dtype)
197
+ return dir_unit * (r * sigma)
198
+
199
+ def _expand_beam_state(
200
+ self,
201
+ hist_sents: List[str],
202
+ hist_emb: torch.Tensor,
203
+ score: float,
204
+ recent_dirs: List[torch.Tensor],
205
+ eos_emb: torch.Tensor,
206
+ ) -> List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]]:
207
+ """Expand one beam state into candidate next states.
208
+
209
+ Returns a list of (new_hist_sents, new_hist_emb, new_score, new_recent_dirs).
210
+ """
211
+ final_hidden = self._forward_hidden(hist_emb)
212
+ out: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
213
+
214
+ for _ in range(max(1, int(self._cfg.latent_samples_per_step))):
215
+ dir_unit = self._sample_noise_direction(final_hidden, recent_dirs)
216
+ noise = self._sample_noise(final_hidden, dir_unit)
217
+ h_perturbed = final_hidden + noise
218
+ z = self.reverse_proj(h_perturbed.unsqueeze(0))
219
+
220
+ texts, scores = self.sonar_decoder.predict(
221
+ z,
222
+ target_lang="eng_Latn",
223
+ beam_size=int(self._cfg.decoder_beam_size),
224
+ normalize_scores=bool(self._cfg.normalize_sentence_scores),
225
+ max_seq_len=self._cfg.decoder_max_len,
226
+ temperature=float(self._cfg.decoder_temperature),
227
+ return_scores=True,
228
+ )
229
+ text = texts[0]
230
+ sent_logprob = float(scores[0])
231
+
232
+ z_re = self.t2vec.predict([text], source_lang="eng_Latn").to(self.device)
233
+
234
+ cand_score = score + sent_logprob
235
+ new_recent = self._update_recent_dirs(recent_dirs, dir_unit, self._cfg.repetition_memory)
236
+
237
+ new_hist_sents = hist_sents + [text]
238
+ new_hist_emb = torch.cat([hist_emb, z_re], dim=0)
239
+
240
+ out.append((new_hist_sents, new_hist_emb, cand_score, new_recent))
241
+
242
+ return out
243
+
244
+ def _apply_repetition_penalty_to_direction(
245
+ self, g: torch.Tensor, penalty: float, memory_cap: int, recent_dirs: List[torch.Tensor]
246
+ ) -> torch.Tensor:
247
+ """Mean-shift (A+) repetition penalty in latent direction space.
248
+
249
+ - penalty is clamped to [0, 1].
250
+ - penalty = 0 → no shift (q = 0.5).
251
+ - penalty = 1 → maximum shift (q ≈ q_min).
252
+ Mapping: q = 0.5^(1-penalty) * q_min^(penalty), beta = Phi^{-1}(1 - q),
253
+ and we set g' = g - beta * b_unit, where b_unit is the normalized average of recent directions.
254
+ """
255
+ if memory_cap <= 0 or len(recent_dirs) == 0:
256
+ return g
257
+
258
+ # Aggregate and normalize recent directions
259
+ B = torch.stack(
260
+ [u.to(device=g.device, dtype=g.dtype) for u in recent_dirs[-int(memory_cap):]], dim=0
261
+ )
262
+ b = B.mean(dim=0)
263
+ bn = b.norm(p=2)
264
+ if not torch.isfinite(bn) or bn <= 1e-12:
265
+ return g
266
+ b_unit = b / bn
267
+
268
+ # Clamp and map penalty → beta via q
269
+ rp = float(penalty)
270
+ if rp < 0.0:
271
+ rp = 0.0
272
+ if rp > 1.0:
273
+ rp = 1.0
274
+ q_min = 1e-12
275
+ log_q = (1.0 - rp) * torch.log(torch.tensor(0.5, device=g.device, dtype=torch.float32))
276
+ log_q = log_q + rp * torch.log(torch.tensor(q_min, device=g.device, dtype=torch.float32))
277
+ q = torch.exp(log_q)
278
+ p = torch.clamp(1.0 - q, 1e-12, 1.0 - 1e-12)
279
+ beta = torch.sqrt(torch.tensor(2.0, device=g.device, dtype=g.dtype)) * torch.special.erfinv(2.0 * p - 1.0)
280
+ beta = torch.clamp(beta, 0.0, 7.5)
281
+ return g - (beta * b_unit)
282
+
283
+ def _all_close_to_eos(self, beams, eos_emb: torch.Tensor) -> bool:
284
+ for (_, emb, _, _) in beams:
285
+ last = emb[-1:, :]
286
+ sim = F.cosine_similarity(last, eos_emb, dim=1).item()
287
+ if sim < float(self._cfg.eos_threshold):
288
+ return False
289
+ return True
290
+
291
+ # --- factory ---
292
+ @classmethod
293
+ def load_from_checkpoint(
294
+ cls,
295
+ checkpoint_dir: str,
296
+ device: Optional[torch.device] = None,
297
+ generation_config: Optional[SONARLLMGenerationConfig] = None,
298
+ ) -> "SONARLLMGenerator":
299
+ """Load generator from a folder with config.json and weights.
300
+
301
+ The folder is expected to contain:
302
+ - config.json (with keys: pretrained_model_name_or_path, llama_config?, embed_dim)
303
+ - pytorch_model.bin (or model_state_dict inside the saved file)
304
+ """
305
+ import json
306
+ import os
307
+ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
308
+ from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
309
+ from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
310
+
311
+ if device is None:
312
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
313
+
314
+ cfg_path = os.path.join(checkpoint_dir, "config.json")
315
+ with open(cfg_path, "r", encoding="utf-8") as f:
316
+ cfg = json.load(f)
317
+
318
+ tokenizer = AutoTokenizer.from_pretrained(cfg["pretrained_model_name_or_path"])
319
+ tokenizer.pad_token = tokenizer.eos_token
320
+
321
+ llama_cfg_dict = cfg.get("llama_config", {})
322
+ llama_cfg_dict["vocab_size"] = len(tokenizer)
323
+ llama_cfg_dict["pad_token_id"] = tokenizer.pad_token_id
324
+ llama_cfg_dict["bos_token_id"] = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 128000
325
+ llama_cfg_dict["eos_token_id"] = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 128001
326
+ llama_cfg = LlamaConfig(**llama_cfg_dict) if "llama_config" in cfg else LlamaConfig()
327
+
328
+ llama_model = LlamaForCausalLM(llama_cfg).to(device).eval()
329
+
330
+ hidden_size = llama_cfg.hidden_size
331
+ embed_dim = cfg.get("embed_dim", 1024)
332
+
333
+ t2vec_model = TextToEmbeddingModelPipeline(
334
+ encoder="text_sonar_basic_encoder",
335
+ tokenizer="text_sonar_basic_encoder",
336
+ device=device,
337
+ ).eval()
338
+
339
+ vec2text_model = EmbeddingToTextModelPipeline(
340
+ decoder="text_sonar_basic_decoder",
341
+ tokenizer="text_sonar_basic_encoder",
342
+ device=device,
343
+ ).eval()
344
+
345
+ forward_projector = Projector(embed_dim, hidden_size).to(device).eval()
346
+ reverse_projector = Projector(hidden_size, embed_dim).to(device).eval()
347
+
348
+ gen = cls(
349
+ llama_model,
350
+ forward_projector,
351
+ reverse_projector,
352
+ vec2text_model,
353
+ t2vec_model,
354
+ device
355
+ )
356
+
357
+ # Load weights into generator to cover llama + projectors
358
+ ckpt_bin = os.path.join(checkpoint_dir, "pytorch_model.bin")
359
+ state = torch.load(ckpt_bin, map_location=device)
360
+ state = state.get("model_state_dict", state)
361
+ raw = gen.module if hasattr(gen, "module") else gen
362
+ raw.load_state_dict(state, strict=False)
363
+
364
+ return gen
365
+
366
+
sonarllm_model/sonarllm_model/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .sonar_llm_model import SONARLLMGenerator, SONARLLMGenerationConfig
2
+
3
+
4
+
5
+
sonarllm_model/sonarllm_model/embedding_to_text_with_scores.py ADDED
@@ -0,0 +1,94 @@
1
+ import warnings
2
+ from typing import Iterable, List, Optional
3
+
4
+ import torch
5
+
6
+ from fairseq2.generation import (
7
+ BeamSearchSeq2SeqGenerator,
8
+ Sampler,
9
+ SamplingSeq2SeqGenerator,
10
+ Seq2SeqGenerator,
11
+ SequenceToTextConverter,
12
+ )
13
+
14
+ from sonar.inference_pipelines.utils import add_progress_bar
15
+ from sonar.inference_pipelines.text import (
16
+ EmbeddingToTextModelPipeline as _BaseEmbeddingToTextModelPipeline,
17
+ )
18
+ from fairseq2.data.data_pipeline import read_sequence
19
+
20
+
21
+ class EmbeddingToTextModelPipeline(_BaseEmbeddingToTextModelPipeline):
22
+ """Drop-in replacement that can also return sentence log-probabilities via return_scores.
23
+
24
+ - When return_scores=False (default), behaves exactly like the base pipeline and returns List[str].
25
+ - When return_scores=True, returns a tuple (List[str], List[float]) where each float is the
26
+ hypothesis score from fairseq2 (sum of token log-probabilities if normalize_scores=False,
27
+ otherwise length-normalized per fairseq2 semantics).
28
+ """
29
+
30
+ @torch.inference_mode()
31
+ def predict(
32
+ self,
33
+ inputs: torch.Tensor,
34
+ target_lang: str,
35
+ batch_size: int = 5,
36
+ progress_bar: bool = False,
37
+ sampler: Optional[Sampler] = None,
38
+ return_scores: bool = False,
39
+ **generator_kwargs,
40
+ ):
41
+ if sampler is not None:
42
+ generator: Seq2SeqGenerator = SamplingSeq2SeqGenerator(
43
+ self.model, sampler, **generator_kwargs
44
+ )
45
+ else:
46
+ generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs)
47
+
48
+ converter = SequenceToTextConverter(
49
+ generator,
50
+ self.tokenizer,
51
+ task="translation",
52
+ target_lang=target_lang,
53
+ )
54
+
55
+ def _do_translate(src_tensors: List[torch.Tensor]):
56
+ texts, gen_out = converter.batch_convert(
57
+ torch.stack(src_tensors).to(self.device), None
58
+ )
59
+ if return_scores:
60
+ scores: List[float] = []
61
+ for hyps in gen_out.hypotheses:
62
+ if len(hyps) == 0 or hyps[0].score is None:
63
+ scores.append(0.0)
64
+ else:
65
+ scores.append(float(hyps[0].score))
66
+ return texts, scores
67
+ return texts
68
+
69
+ pipeline: Iterable = (
70
+ read_sequence(list(inputs))
71
+ .bucket(batch_size)
72
+ .map(_do_translate)
73
+ .and_return()
74
+ )
75
+
76
+ if progress_bar:
77
+ pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size)
78
+
79
+ results: List = list(iter(pipeline))
80
+
81
+ if not return_scores:
82
+ # results is List[List[str]] → flatten
83
+ return [text for batch_texts in results for text in batch_texts]
84
+
85
+ # results is List[Tuple[List[str], List[float]]] → flatten both
86
+ all_texts: List[str] = []
87
+ all_scores: List[float] = []
88
+ for batch in results:
89
+ batch_texts, batch_scores = batch
90
+ all_texts.extend(batch_texts)
91
+ all_scores.extend(batch_scores)
92
+ return all_texts, all_scores
93
+
94
+
sonarllm_model/sonarllm_model/sonar_llm_model.py ADDED
@@ -0,0 +1,366 @@
1
+ import time
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple
4
+
5
+
6
+ import nltk
7
+ from nltk.tokenize import sent_tokenize
8
+ nltk.download("punkt", quiet=True)
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
15
+ from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
16
+
17
+ class Projector(nn.Module):
18
+ def __init__(self, in_dim: int, out_dim: int):
19
+ super().__init__()
20
+ self.linear = nn.Linear(in_dim, out_dim)
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return self.linear(x)
24
+
25
+ @dataclass
26
+ class SONARLLMGenerationConfig:
27
+ # Outer sentence-level beam
28
+ sentence_beam_size: int = 4
29
+ latent_samples_per_step: int = 4 # M latent variants per active beam state
30
+
31
+ # Token-level decoder params
32
+ decoder_beam_size: int = 5 # default in fairseq2
33
+ decoder_temperature: float = 1.0 # default in fairseq2
34
+ normalize_sentence_scores: bool = True # False → sum of token log-probs
35
+ decoder_max_len: int = 256
36
+
37
+ # Latent sampling
38
+ temperature: float = 0.4
39
+ latent_top_p: Optional[float] = None # 0<p<=1 or None for Gaussian
40
+ temperature_mode: str = "relative" # "absolute" | "relative"
41
+
42
+ # Repetition control in latent space
43
+ repetition_penalty: float = 0.0
44
+ repetition_memory: int = 0
45
+
46
+ # Termination
47
+ max_sentences: int = 32
48
+ eos_threshold: float = 0.98
49
+
50
+
51
+ class SONARLLMGenerator(torch.nn.Module):
52
+ """Sentence-level beam over latent reversed embeddings using SONAR decoder.
53
+
54
+ For each step:
55
+ - Run LLaMA on the sentence embedding history to get final hidden.
56
+ - Sample multiple latent directions (temperature/latent_top_p, with repetition penalty).
57
+ - Project back to the SONAR embedding space via `reverse_proj` and decode text with the SONAR decoder.
58
+ - Score each candidate using decoder sentence logprob (+ optional shaping).
59
+ - Keep top `sentence_beam_size` states and continue until EOS or max sentences.
60
+
61
+ This class does NOT modify existing project files and can be used standalone.
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ llama_model: nn.Module,
67
+ forward_proj: nn.Module,
68
+ reverse_proj: nn.Module,
69
+ sonar_decoder: EmbeddingToTextModelPipeline,
70
+ t2vec_model: TextToEmbeddingModelPipeline,
71
+ device: torch.device,
72
+ ) -> None:
73
+ super().__init__()
74
+ self.llama_model = llama_model.eval()
75
+ self.forward_proj = forward_proj.eval()
76
+ self.reverse_proj = reverse_proj.eval()
77
+ self.sonar_decoder = sonar_decoder.eval()
78
+ self.t2vec = t2vec_model.eval()
79
+ self.device = device
80
+
81
+ @torch.no_grad()
82
+ def generate(self, prefix_text: str, eos_emb: torch.Tensor, cfg: Optional[SONARLLMGenerationConfig] = None) -> str:
83
+ # Normalize and attach config to the instance for helper use
84
+ if cfg is None:
85
+ cfg = SONARLLMGenerationConfig()
86
+ self._cfg = cfg
87
+ sents = sent_tokenize(prefix_text)
88
+ if len(sents) == 0:
89
+ sents = [prefix_text.strip()]
90
+
91
+ # Initialize prefix embeddings
92
+ emb_seq = self.t2vec.predict(sents, source_lang="eng_Latn").to(self.device)
93
+
94
+ # Beam state tuple: (sentences, embeddings_seq, cumulative_score, recent_dirs)
95
+ beams: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = [
96
+ (sents[:], emb_seq, 0.0, [])
97
+ ]
98
+
99
+ steps = 0
100
+ while steps < self._cfg.max_sentences:
101
+ steps += 1
102
+ candidates: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
103
+
104
+ for (hist_sents, hist_emb, score, recent_dirs) in beams:
105
+ candidates.extend(
106
+ self._expand_beam_state(hist_sents, hist_emb, score, recent_dirs, eos_emb)
107
+ )
108
+
109
+ # Keep top-k beams
110
+ if len(candidates) == 0:
111
+ break
112
+ candidates.sort(key=lambda b: b[2], reverse=True)
113
+ beams = candidates[: int(self._cfg.sentence_beam_size)]
114
+
115
+ # If all beams look ended by EOS threshold, stop early
116
+ if self._all_close_to_eos(beams, eos_emb):
117
+ break
118
+
119
+ best = max(beams, key=lambda b: b[2])
120
+ return self._join_sentences(best[0])
121
+
122
+ # --- internals ---
123
+
124
+ @torch.no_grad()
125
+ def _forward_hidden(self, emb_seq: torch.Tensor) -> torch.Tensor:
126
+ proj = self.forward_proj(emb_seq.unsqueeze(0)) if emb_seq.ndim == 2 else self.forward_proj(emb_seq)
127
+ out = self.llama_model(inputs_embeds=proj, output_hidden_states=True)
128
+ hidden = out.hidden_states[-1]
129
+ return hidden[0, -1, :]
130
+
131
+ def _join_sentences(self, sents: List[str]) -> str:
132
+ return " ".join(sents)
133
+
134
+ def _update_recent_dirs(
135
+ self, recent: List[torch.Tensor], u: torch.Tensor, memory_cap: int
136
+ ) -> List[torch.Tensor]:
137
+ if memory_cap <= 0:
138
+ return recent
139
+ if not torch.isfinite(u).all():
140
+ return recent
141
+ new_recent = recent + [u.detach().to("cpu")]
142
+ if len(new_recent) > int(memory_cap):
143
+ new_recent = new_recent[-int(memory_cap) :]
144
+ return new_recent
145
+
146
+ def _sample_noise_direction(
147
+ self, final_hidden: torch.Tensor, recent_dirs: List[torch.Tensor]
148
+ ) -> torch.Tensor:
149
+ g = torch.randn_like(final_hidden)
150
+ if (
151
+ self._cfg.repetition_penalty is not None
152
+ and float(self._cfg.repetition_penalty) > 0.0  # 0.0 disables the latent repetition penalty
153
+ and self._cfg.repetition_memory > 0
154
+ and len(recent_dirs) > 0
155
+ ):
156
+ g = self._apply_repetition_penalty_to_direction(
157
+ g, float(self._cfg.repetition_penalty), int(self._cfg.repetition_memory), recent_dirs
158
+ )
159
+ return g / (g.norm(p=2) + 1e-12)
160
+
161
+ def _sample_noise(
162
+ self, final_hidden: torch.Tensor, dir_unit: torch.Tensor
163
+ ) -> torch.Tensor:
164
+ t = float(self._cfg.temperature)
165
+ if t <= 0.0:
166
+ return torch.zeros_like(final_hidden)
167
+
168
+ if self._cfg.temperature_mode not in ("absolute", "relative"):
169
+ raise ValueError(f"Unsupported temperature_mode: {self._cfg.temperature_mode}")
170
+
171
+ if self._cfg.temperature_mode == "absolute":
172
+ sigma = torch.tensor(t, device=final_hidden.device, dtype=final_hidden.dtype)
173
+ else:
174
+ rms = torch.sqrt(torch.mean(final_hidden.to(torch.float32) ** 2))
175
+ rms = torch.clamp(rms, min=1e-12).to(dtype=final_hidden.dtype, device=final_hidden.device)
176
+ sigma = rms * t
177
+
178
+ top_p = self._cfg.latent_top_p
179
+ if top_p is None:
180
+ top_p = 1.0
181
+ return self._sample_truncated_normal_like(final_hidden, float(top_p), sigma, dir_unit)
182
+
183
+ def _sample_truncated_normal_like(
184
+ self, base_vector: torch.Tensor, top_p: float, sigma: torch.Tensor, dir_unit: torch.Tensor
185
+ ) -> torch.Tensor:
186
+ # Wilson–Hilferty approximation for ChiSquare quantiles
187
+ dim = base_vector.numel()
188
+ device = base_vector.device
189
+ u = torch.rand((), device=device, dtype=torch.float32)
190
+ p = torch.clamp(u * float(top_p), min=1e-12, max=1.0 - 1e-12)
191
+ k = torch.tensor(float(dim), device=device, dtype=torch.float32)
192
+ z = torch.sqrt(torch.tensor(2.0, device=device, dtype=torch.float32)) * torch.special.erfinv(2.0 * p - 1.0)
193
+ term = 1.0 - 2.0 / (9.0 * k) + z * torch.sqrt(2.0 / (9.0 * k))
194
+ term = torch.clamp(term, min=1e-12)
195
+ s = k * (term ** 3)
196
+ r = torch.sqrt(torch.clamp(s, min=1e-12)).to(dtype=base_vector.dtype)
197
+ return dir_unit * (r * sigma)
198
+
199
+ def _expand_beam_state(
200
+ self,
201
+ hist_sents: List[str],
202
+ hist_emb: torch.Tensor,
203
+ score: float,
204
+ recent_dirs: List[torch.Tensor],
205
+ eos_emb: torch.Tensor,
206
+ ) -> List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]]:
207
+ """Expand one beam state into candidate next states.
208
+
209
+ Returns a list of (new_hist_sents, new_hist_emb, new_score, new_recent_dirs).
210
+ """
211
+ final_hidden = self._forward_hidden(hist_emb)
212
+ out: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
213
+
214
+ for _ in range(max(1, int(self._cfg.latent_samples_per_step))):
215
+ dir_unit = self._sample_noise_direction(final_hidden, recent_dirs)
216
+ noise = self._sample_noise(final_hidden, dir_unit)
217
+ h_perturbed = final_hidden + noise
218
+ z = self.reverse_proj(h_perturbed.unsqueeze(0))
219
+
220
+ texts, scores = self.sonar_decoder.predict(
221
+ z,
222
+ target_lang="eng_Latn",
223
+ beam_size=int(self._cfg.decoder_beam_size),
224
+ normalize_scores=bool(self._cfg.normalize_sentence_scores),
225
+ max_seq_len=self._cfg.decoder_max_len,
226
+ temperature=float(self._cfg.decoder_temperature),
227
+ return_scores=True,
228
+ )
229
+ text = texts[0]
230
+ sent_logprob = float(scores[0])
231
+
232
+ z_re = self.t2vec.predict([text], source_lang="eng_Latn").to(self.device)
233
+
234
+ cand_score = score + sent_logprob
235
+ new_recent = self._update_recent_dirs(recent_dirs, dir_unit, self._cfg.repetition_memory)
236
+
237
+ new_hist_sents = hist_sents + [text]
238
+ new_hist_emb = torch.cat([hist_emb, z_re], dim=0)
239
+
240
+ out.append((new_hist_sents, new_hist_emb, cand_score, new_recent))
241
+
242
+ return out
243
+
244
+ def _apply_repetition_penalty_to_direction(
245
+ self, g: torch.Tensor, penalty: float, memory_cap: int, recent_dirs: List[torch.Tensor]
246
+ ) -> torch.Tensor:
247
+ """Mean-shift (A+) repetition penalty in latent direction space.
248
+
249
+ - penalty is clamped to [0, 1].
250
+ - penalty = 0 → no shift (q = 0.5).
251
+ - penalty = 1 → maximum shift (q ≈ q_min).
252
+ Mapping: q = 0.5^(1-penalty) * q_min^(penalty), beta = Phi^{-1}(1 - q),
253
+ and we set g' = g - beta * b_unit, where b_unit is the normalized average of recent directions.
254
+ """
255
+ if memory_cap <= 0 or len(recent_dirs) == 0:
256
+ return g
257
+
258
+ # Aggregate and normalize recent directions
259
+ B = torch.stack(
260
+ [u.to(device=g.device, dtype=g.dtype) for u in recent_dirs[-int(memory_cap):]], dim=0
261
+ )
262
+ b = B.mean(dim=0)
263
+ bn = b.norm(p=2)
264
+ if not torch.isfinite(bn) or bn <= 1e-12:
265
+ return g
266
+ b_unit = b / bn
267
+
268
+ # Clamp and map penalty → beta via q
269
+ rp = float(penalty)
270
+ if rp < 0.0:
271
+ rp = 0.0
272
+ if rp > 1.0:
273
+ rp = 1.0
274
+ q_min = 1e-12
275
+ log_q = (1.0 - rp) * torch.log(torch.tensor(0.5, device=g.device, dtype=torch.float32))
276
+ log_q = log_q + rp * torch.log(torch.tensor(q_min, device=g.device, dtype=torch.float32))
277
+ q = torch.exp(log_q)
278
+ p = torch.clamp(1.0 - q, 1e-12, 1.0 - 1e-12)
279
+ beta = torch.sqrt(torch.tensor(2.0, device=g.device, dtype=g.dtype)) * torch.special.erfinv(2.0 * p - 1.0)
280
+ beta = torch.clamp(beta, 0.0, 7.5)
281
+ return g - (beta * b_unit)
282
+
283
+ def _all_close_to_eos(self, beams, eos_emb: torch.Tensor) -> bool:
284
+ for (_, emb, _, _) in beams:
285
+ last = emb[-1:, :]
286
+ sim = F.cosine_similarity(last, eos_emb, dim=1).item()
287
+ if sim < float(self._cfg.eos_threshold):
288
+ return False
289
+ return True
290
+
291
+ # --- factory ---
292
+ @classmethod
293
+ def load_from_checkpoint(
294
+ cls,
295
+ checkpoint_dir: str,
296
+ device: Optional[torch.device] = None,
297
+ generation_config: Optional[SONARLLMGenerationConfig] = None,
298
+ ) -> "SONARLLMGenerator":
299
+ """Load generator from a folder with config.json and weights.
300
+
301
+ The folder is expected to contain:
302
+ - config.json (with keys: pretrained_model_name_or_path, llama_config?, embed_dim)
303
+ - pytorch_model.bin (or model_state_dict inside the saved file)
304
+ """
305
+ import json
306
+ import os
307
+ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
308
+ from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
309
+ from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
310
+
311
+ if device is None:
312
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
313
+
314
+ cfg_path = os.path.join(checkpoint_dir, "config.json")
315
+ with open(cfg_path, "r", encoding="utf-8") as f:
316
+ cfg = json.load(f)
317
+
318
+ tokenizer = AutoTokenizer.from_pretrained(cfg["pretrained_model_name_or_path"])
319
+ tokenizer.pad_token = tokenizer.eos_token
320
+
321
+ llama_cfg_dict = cfg.get("llama_config", {})
322
+ llama_cfg_dict["vocab_size"] = len(tokenizer)
323
+ llama_cfg_dict["pad_token_id"] = tokenizer.pad_token_id
324
+ llama_cfg_dict["bos_token_id"] = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 128000
325
+ llama_cfg_dict["eos_token_id"] = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 128001
326
+ llama_cfg = LlamaConfig(**llama_cfg_dict) if "llama_config" in cfg else LlamaConfig()
327
+
328
+ llama_model = LlamaForCausalLM(llama_cfg).to(device).eval()
329
+
330
+ hidden_size = llama_cfg.hidden_size
331
+ embed_dim = cfg.get("embed_dim", 1024)
332
+
333
+ t2vec_model = TextToEmbeddingModelPipeline(
334
+ encoder="text_sonar_basic_encoder",
335
+ tokenizer="text_sonar_basic_encoder",
336
+ device=device,
337
+ ).eval()
338
+
339
+ vec2text_model = EmbeddingToTextModelPipeline(
340
+ decoder="text_sonar_basic_decoder",
341
+ tokenizer="text_sonar_basic_encoder",
342
+ device=device,
343
+ ).eval()
344
+
345
+ forward_projector = Projector(embed_dim, hidden_size).to(device).eval()
346
+ reverse_projector = Projector(hidden_size, embed_dim).to(device).eval()
347
+
348
+ gen = cls(
349
+ llama_model,
350
+ forward_projector,
351
+ reverse_projector,
352
+ vec2text_model,
353
+ t2vec_model,
354
+ device
355
+ )
356
+
357
+ # Load weights into generator to cover llama + projectors
358
+ ckpt_bin = os.path.join(checkpoint_dir, "pytorch_model.bin")
359
+ state = torch.load(ckpt_bin, map_location=device)
360
+ state = state.get("model_state_dict", state)
361
+ raw = gen.module if hasattr(gen, "module") else gen
362
+ raw.load_state_dict(state, strict=False)
363
+
364
+ return gen
365
+
366
+