habdine committed
Commit 794a213
1 Parent(s): 6ccdea6

Upload utils.py

Files changed (1)
  1. utils.py +589 -0
utils.py ADDED
@@ -0,0 +1,589 @@
import warnings

import torch
import torch.nn as nn
import torch.distributed as dist
from typing import Optional, Tuple, Union, List

from transformers import GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
from transformers.generation.utils import (
    GreedySearchOutput,
    GreedySearchDecoderOnlyOutput,
    GreedySearchEncoderDecoderOutput,
    BeamSearchOutput,
    BeamSearchDecoderOnlyOutput,
    BeamSearchEncoderDecoderOutput,
)
from transformers.generation.beam_search import BeamScorer

class _GPT2LMHeadModel(GPT2LMHeadModel):

    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
        self.config = config

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, encoder_outputs=None, **kwargs):
        '''
        This function is an edited version of the prepare_inputs_for_generation function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
        '''
        token_type_ids = kwargs.get("token_type_ids", None)
        # only keep the last token of input_ids if past_key_values is defined in kwargs
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        if self.config.prot2text_version == "1.1" or self.config.prot2text_version == "1.2":
            encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
        elif self.config.prot2text_version == "1.0":
            encoder_attention_mask = None

        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        model_specific_kwargs = {
            "encoder_hidden_states": encoder_outputs['hidden_states'],
        }

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "encoder_attention_mask": encoder_attention_mask,
            **model_specific_kwargs
        }

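    # Usage sketch (illustrative note, not in the original upload): `generate` keeps `encoder_outputs`
    # inside `model_kwargs`, so the override above expects a mapping that exposes a `hidden_states`
    # entry holding the output of the protein encoder. The tensor names below are assumptions:
    #     encoder_outputs = {"hidden_states": encoder_states}      # (batch, src_len, hidden_size)
    #     model.generate(input_ids,
    #                    encoder_outputs=encoder_outputs,
    #                    encoder_attention_mask=src_mask,          # read only for versions "1.1"/"1.2"
    #                    use_cache=True)
    # For prot2text_version "1.0" the encoder attention mask is ignored and set to None.
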
    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        '''
        This function is an edited version of the greedy_search function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
        '''

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if not self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            try:
                if stopping_criteria(input_ids, scores):
                    this_peer_finished = True
            except Exception:
                # some stopping criteria return one flag per sequence instead of a single bool
                if all(stopping_criteria(input_ids, scores)):
                    this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

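    # Worked example (illustrative note, not in the original upload): in `greedy_search` above,
    #     next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    # keeps the argmax token for rows that are still generating and writes `pad_token_id` into rows
    # that have already emitted EOS. For instance, with
    #     next_tokens          = [42, 17]
    #     unfinished_sequences = [ 1,  0]   # second sequence finished earlier
    #     pad_token_id         = 0
    # the update gives [42, 0], so finished sequences only ever grow by padding tokens.
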
    def _greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        # compatibility wrapper for `generate` implementations that call the underscore-prefixed name
        return self.greedy_search(
            input_ids,
            logits_processor,
            stopping_criteria,
            max_length,
            pad_token_id,
            eos_token_id,
            output_attentions,
            output_hidden_states,
            output_scores,
            return_dict_in_generate,
            synced_gpus,
            streamer,
            **model_kwargs,
        )

    def _beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        # compatibility wrapper for `generate` implementations that call the underscore-prefixed name
        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor,
            stopping_criteria,
            max_length,
            pad_token_id,
            eos_token_id,
            output_attentions,
            output_hidden_states,
            output_scores,
            return_dict_in_generate,
            synced_gpus,
            **model_kwargs,
        )

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        '''
        This function is an edited version of the beam_search function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
        '''
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            # next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            # next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                next_token_scores_processed
            )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if not self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )

            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            try:
                if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                    if not synced_gpus:
                        break
                    else:
                        this_peer_finished = True
            except Exception:
                # some stopping criteria return one flag per sequence instead of a single bool
                if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
                    if not synced_gpus:
                        break
                    else:
                        this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]

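# Worked example (illustrative note, not in the original upload): in `beam_search` above the scores
# are reshaped to (batch_size, num_beams * vocab_size) before `torch.topk`, so each selected flat
# index encodes both the source beam and the token id. With num_beams = 2 and vocab_size = 50257,
# a flat index of 50299 decodes as
#     next_indices = 50299 // 50257 = 1    # the token came from the second beam
#     next_tokens  = 50299 %  50257 = 42   # token id within the vocabulary
# which is exactly what the `torch.div(..., rounding_mode="floor")` / modulo pair computes.
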
class CABlock(nn.Module):
    '''
    This class is an edited version of the GPT2 decoder block from HuggingFace's transformers
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
    '''
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
        self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:

        # cross-attention over the encoder hidden states
        residual = hidden_states
        hidden_states = self.ln_cross_attn(hidden_states)
        cross_attn_outputs = self.crossattention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        attn_output = cross_attn_outputs[0]
        # residual connection
        hidden_states = residual + attn_output

        # feed-forward block
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        return (hidden_states,)
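

# Minimal smoke test (illustrative sketch, not part of the original upload): build a single CABlock
# from a small GPT2 configuration and run random decoder states against random encoder states.
# All sizes below are arbitrary assumptions chosen only for the demonstration.
if __name__ == "__main__":
    from transformers import GPT2Config

    config = GPT2Config(n_embd=64, n_head=4, n_layer=2, add_cross_attention=True)
    block = CABlock(config, layer_idx=0)

    decoder_states = torch.randn(2, 5, 64)   # (batch, tgt_len, hidden_size)
    encoder_states = torch.randn(2, 7, 64)   # (batch, src_len, hidden_size)
    (hidden_states,) = block(decoder_states, encoder_hidden_states=encoder_states)
    print(hidden_states.shape)  # expected: torch.Size([2, 5, 64])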