kashif (HF Staff) committed
Commit a42769a · verified · 1 Parent(s): 5fdd203

Upload 2 files

custom_generate/generate.py ADDED
@@ -0,0 +1,306 @@
+ from collections import deque
+ from typing import Any, Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+
+ from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+ from transformers.generation.logits_process import (
+     TemperatureLogitsWarper,
+     TopKLogitsWarper,
+     TopPLogitsWarper,
+ )
+ from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput
+
+
+ def generate(
+     model: Any,
+     input_ids: torch.LongTensor,
+     logits_processor: Optional[LogitsProcessorList] = None,
+     stopping_criteria: Optional[StoppingCriteriaList] = None,
+     generation_config: Optional[GenerationConfig] = None,
+     synced_gpus: bool = False,
+     streamer: Optional[Any] = None,
+     **model_kwargs,
+ ) -> Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, torch.LongTensor]:
+     """Custom decoding with DeepCONF (confidence-based early stopping).
+
+     Args:
+         model: PreTrainedModel with a LM head.
+         input_ids: Prompt ids of shape (batch, seq_len).
+         logits_processor: Optional logits processors.
+         stopping_criteria: Optional stopping criteria.
+         generation_config: GenerationConfig controlling sampling/outputs. DeepCONF reads
+             the extra attributes `enable_conf` (bool, default False), `window_size`
+             (int, default 2048), `threshold` (float, default 17.0), and
+             `output_confidences` (bool, default False) from this config.
+         synced_gpus: Keep looping to max length for distributed setups.
+         streamer: Optional streamer for incremental tokens.
+         **model_kwargs: Forward pass kwargs (e.g., attention_mask).
+
+     Returns:
+         GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, or LongTensor
+         depending on `return_dict_in_generate` and model type.
+     """
+
+     # Default to empty processor/criteria lists when None is passed (both are optional)
+     logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+     stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+     # Get DeepCONF parameters from generation_config or set defaults
+     enable_conf = getattr(generation_config, "enable_conf", False)
+     window_size = getattr(generation_config, "window_size", 2048)
+     threshold = getattr(generation_config, "threshold", 17.0)  # Default threshold for confidence (positive value)
+
+     # If DeepCONF is not enabled, fall back to standard sampling
+     if not enable_conf:
+         return model._sample(
+             input_ids,
+             logits_processor=logits_processor,
+             stopping_criteria=stopping_criteria,
+             generation_config=generation_config,
+             synced_gpus=synced_gpus,
+             streamer=streamer,
+             **model_kwargs,
+         )
+
+     # Initialize values
+     # Handle pad token properly (following HF best practices)
+     pad_token_id = generation_config.pad_token_id
+     if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
+         pad_token_id = generation_config._pad_token_tensor
+     if pad_token_id is None and hasattr(model.config, "pad_token_id"):
+         pad_token_id = model.config.pad_token_id
+     if pad_token_id is None and generation_config.eos_token_id is not None:
+         # Use eos token as pad token if not set
+         pad_token_id = generation_config.eos_token_id
+
+     output_attentions = generation_config.output_attentions
+     output_hidden_states = generation_config.output_hidden_states
+     output_scores = generation_config.output_scores
+     output_logits = generation_config.output_logits
+     return_dict_in_generate = generation_config.return_dict_in_generate
+     output_confidences = getattr(generation_config, "output_confidences", False)
+     has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+     do_sample = generation_config.do_sample
+
+     # Initialize attention / hidden states / scores tuples
+     scores = () if (return_dict_in_generate and output_scores) else None
+     raw_logits = () if (return_dict_in_generate and output_logits) else None
+     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+     decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+     # If model is an encoder-decoder, retrieve encoder attention weights and hidden states
+     if return_dict_in_generate and model.config.is_encoder_decoder:
+         encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+         encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+
+     # Keep track of which sequences are already finished
+     batch_size, cur_len = input_ids.shape[:2]
+     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+     # Use public kv-cache via past_key_values
+
+     # Initialize confidence tracking
+     # Use deque for sliding window with fixed size
+     conf_group_lists = [deque(maxlen=window_size) for _ in range(batch_size)]
+     conf_grouped_sums = [0.0 for _ in range(batch_size)]  # Running sums for efficient mean calculation
+
+     # Initialize via prepare_inputs_for_generation
+
+     # Optional per-step confidences for debugging/visualization
+     step_confidences = [] if (return_dict_in_generate and output_confidences) else None
+
+     # Main generation loop using public controls
+     steps = 0
+     max_new_tokens = getattr(generation_config, "max_new_tokens", None) or 512
+     # Initialize cache_position for first forward over the full prompt
+     # Subsequent steps will pass a single position incrementally
+     model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+     while steps < max_new_tokens and unfinished_sequences.max() != 0:
+         # Prepare model inputs (proper KV cache handling)
+         model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+         # Prepare variable output controls
+         model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+         model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+         # Forward pass with proper KV cache handling
+         with torch.no_grad():
+             outputs = model(**model_inputs, return_dict=True)
+         next_token_logits = outputs.logits[:, -1, :].detach()
+
+         # Update model kwargs for next iteration (public): carry past_key_values
+         if hasattr(outputs, "past_key_values") and outputs.past_key_values is not None:
+             model_kwargs["past_key_values"] = outputs.past_key_values
+
+         # Pre-process distribution with logits processors
+         next_token_scores = logits_processor(input_ids, next_token_logits)
+
+         # Apply logits warpers (e.g., temperature, top-k, top-p) from generation_config
+         warpers = LogitsProcessorList()
+         # Temperature
+         temperature = getattr(generation_config, "temperature", 1.0)
+         if temperature is not None and temperature != 1.0:
+             warpers.append(TemperatureLogitsWarper(temperature))
+         # Top-k
+         top_k = getattr(generation_config, "top_k", None)
+         if top_k is not None and isinstance(top_k, int) and top_k > 0:
+             warpers.append(TopKLogitsWarper(top_k))
+         # Top-p
+         top_p = getattr(generation_config, "top_p", None)
+         if top_p is not None and top_p < 1.0:
+             warpers.append(TopPLogitsWarper(top_p))
+         if len(warpers) > 0:
+             next_token_scores = warpers(input_ids, next_token_scores)
+
+         # Store scores, attentions and hidden_states when required
+         if return_dict_in_generate:
+             if output_scores:
+                 scores += (next_token_scores,)
+             if output_logits:
+                 raw_logits += (next_token_logits,)
+             if output_attentions:
+                 decoder_attentions += (
+                     (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
+                 )
+                 if model.config.is_encoder_decoder:
+                     cross_attentions += (outputs.cross_attentions,)
+
+             if output_hidden_states:
+                 decoder_hidden_states += (
+                     (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
+                 )
+
+         # Token selection
+         if do_sample:
+             probs = F.softmax(next_token_scores, dim=-1)
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+         else:
+             next_tokens = torch.argmax(next_token_scores, dim=-1)
+
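+         # DeepCONF early-stopping check (summary of the block below): compute a per-token
+         # confidence as the negative mean log-probability of the candidate tokens that
+         # survive the warpers, excluding the sampled token; keep a sliding window of the
+         # last `window_size` confidences per sequence; once the window is full, mark the
+         # sequence for stopping when the window average falls below `threshold`.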
+         # Calculate confidence using only top-k/top-p filtered candidates (post-logits processors),
+         # excluding the sampled token.
+         # We consider candidates where logits are finite after warpers (e.g., top-k/top-p/temperature).
+         logprobs = F.log_softmax(next_token_scores, dim=-1)
+         candidate_mask = torch.isfinite(next_token_scores)
+
+         deepconf_stopping = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
+         step_conf_values = [0.0] * batch_size  # collect per-sequence confidences for this step (full batch)
+
+         for i in range(batch_size):
+             if not unfinished_sequences[i]:
+                 continue
+
+             # Count valid candidates
+             num_candidates = int(candidate_mask[i].sum().item())
+             if num_candidates <= 1:
+                 conf = 0.0
+             else:
+                 # Sum logprobs over valid candidates and exclude the sampled token's logprob
+                 total_lp = torch.sum(logprobs[i][candidate_mask[i]])
+                 selected_lp = (
+                     logprobs[i, next_tokens[i]]
+                     if candidate_mask[i, next_tokens[i]]
+                     else torch.tensor(0.0, device=logprobs.device)
+                 )
+                 denom = num_candidates - 1
+                 # Negative mean of non-selected candidate logprobs
+                 conf = -((total_lp - selected_lp) / denom).item()
+
+             # Update tracking structures
+             if len(conf_group_lists[i]) >= window_size:
+                 conf_grouped_sums[i] -= conf_group_lists[i][0]
+             conf_group_lists[i].append(conf)
+             conf_grouped_sums[i] += conf
+
+             # Apply confidence-based early stopping when window is full
+             if len(conf_group_lists[i]) >= window_size:
+                 avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
+                 if avg_conf < threshold:
+                     deepconf_stopping[i] = False
+
+             if step_confidences is not None:
+                 step_conf_values[i] = conf
+
+         if step_confidences is not None:
+             # Store this step's confidences as a tensor of shape (batch,)
+             step_confidences.append(torch.tensor(step_conf_values, device=input_ids.device))
+
+         # Finished sentences should have their next token be a padding token
+         if has_eos_stopping_criteria and pad_token_id is not None:
+             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+         # Update generated ids, model inputs, and length for next step
+         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+         # Update attention mask if available
+         if model_kwargs.get("attention_mask") is not None:
+             attn = model_kwargs["attention_mask"]
+             model_kwargs["attention_mask"] = torch.cat(
+                 [attn, torch.ones((batch_size, 1), dtype=attn.dtype, device=attn.device)], dim=-1
+             )
+         # Update cache_position for next step (single next token)
+         model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
+         if streamer is not None:
+             streamer.put(next_tokens.cpu())
+
+         # Update unfinished sequences with standard stopping criteria (per-sequence if available)
+         sc = stopping_criteria(input_ids, scores)
+         if isinstance(sc, torch.Tensor):
+             unfinished_sequences = unfinished_sequences & ~sc
+         elif sc:
+             # global stop
+             unfinished_sequences = torch.zeros_like(unfinished_sequences)
+
+         # Apply DeepCONF stopping
+         unfinished_sequences = unfinished_sequences & deepconf_stopping
+
+         # Early break if all sequences finished and not synchronized
+         if unfinished_sequences.max() == 0 and not synced_gpus:
+             break
+         cur_len += 1
+         steps += 1
+
+         # Clean up outputs to save memory
+         del outputs
+
+     if streamer is not None:
+         streamer.end()
+
+     # Return results
+     if return_dict_in_generate:
+         # Prepare confidences tensor if requested
+         confidences_tensor = None
+         if step_confidences is not None and len(step_confidences) > 0:
+             # Shape: (steps, batch) -> (batch, steps)
+             confidences_tensor = torch.stack(step_confidences, dim=0).transpose(0, 1)
+         if model.config.is_encoder_decoder:
+             output = GenerateEncoderDecoderOutput(
+                 sequences=input_ids,
+                 scores=scores,
+                 logits=raw_logits,
+                 encoder_attentions=encoder_attentions,
+                 encoder_hidden_states=encoder_hidden_states,
+                 decoder_attentions=decoder_attentions,
+                 cross_attentions=cross_attentions,
+                 decoder_hidden_states=decoder_hidden_states,
+                 past_key_values=model_kwargs.get("past_key_values"),
+             )
+             if confidences_tensor is not None:
+                 output["confidences"] = confidences_tensor
+                 try:
+                     setattr(output, "confidences", confidences_tensor)
+                 except Exception:
+                     pass
+             return output
+         else:
+             output = GenerateDecoderOnlyOutput(
+                 sequences=input_ids,
+                 scores=scores,
+                 logits=raw_logits,
+                 attentions=decoder_attentions,
+                 hidden_states=decoder_hidden_states,
+                 past_key_values=model_kwargs.get("past_key_values"),
+             )
+             if confidences_tensor is not None:
+                 output["confidences"] = confidences_tensor
+                 try:
+                     setattr(output, "confidences", confidences_tensor)
+                 except Exception:
+                     pass
+             return output
+     else:
+         return input_ids
custom_generate/requirements.txt ADDED
@@ -0,0 +1,4 @@
+ # DeepCONF custom generation strategy requirements
+ # This implementation only uses PyTorch and Transformers, which should already be available
+ torch>=1.13.0
+ transformers>=4.35.0
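
For context, a minimal usage sketch (not part of this commit). It assumes a transformers release that can load custom generation strategies from the Hub via the `custom_generate` argument of `generate()`, uses placeholder model and repo ids, and passes the DeepCONF options (`enable_conf`, `window_size`, `threshold`, `output_confidences`) as extra GenerationConfig attributes, which generate.py above reads via getattr():

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "your-org/your-causal-lm"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

gen_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    return_dict_in_generate=True,
    enable_conf=True,          # enable DeepCONF (otherwise the code falls back to standard sampling)
    window_size=1024,          # sliding-window length for the group confidence
    threshold=17.0,            # stop a sequence once its window-average confidence drops below this
    output_confidences=True,   # also return per-step confidences
)

inputs = tokenizer("Question: what is 17 * 24?\nAnswer:", return_tensors="pt")
out = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="your-org/deepconf",  # placeholder Hub repo id containing custom_generate/generate.py
    trust_remote_code=True,
)

print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))
print(out["confidences"].shape)  # (batch, generated_steps) when output_confidences=True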