jaandoui committed on
Commit b2f170e
1 Parent(s): db13928

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +40 -40
bert_layers.py CHANGED
@@ -169,12 +169,12 @@ class BertUnpadSelfAttention(nn.Module):
                 self.attention_head_size)
             attention_scores = attention_scores + bias
             attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-            print(f'BUSA: attention_probs 1 shape: {attention_probs.shape}')
+            # print(f'BUSA: attention_probs 1 shape: {attention_probs.shape}')
             attention_probs = self.dropout(attention_probs)
-            print(f'BUSA: attention_probs 2 shape: {attention_probs.shape}')
+            # print(f'BUSA: attention_probs 2 shape: {attention_probs.shape}')
             attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                                  3)  # b s h d
-            print(f'BUSA: attention shape: {attention.shape}')
+            # print(f'BUSA: attention shape: {attention.shape}')
         else:
             # Triton implementation only supports 0 attention dropout
             convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
@@ -185,24 +185,24 @@ class BertUnpadSelfAttention(nn.Module):
                 bias_dtype = bias.dtype
                 bias = bias.to(torch.float16)
                 attention = flash_attn_qkvpacked_func(qkv, bias)
-                print(f'BUSA Triton: attention 0 shape: {attention_probs.shape}')
+                # print(f'BUSA Triton: attention 0 shape: {attention_probs.shape}')
                 attention = attention.to(orig_dtype)
-                print(f'BUSA Triton: attention 1 shape: {attention_probs.shape}')
+                # print(f'BUSA Triton: attention 1 shape: {attention_probs.shape}')
                 bias = bias.to(bias_dtype)
             else:
                 attention = flash_attn_qkvpacked_func(qkv, bias)
-                print(f'BUSA Triton: attention 2 shape: {attention_probs.shape}')
+                # print(f'BUSA Triton: attention 2 shape: {attention_probs.shape}')
         # attn_mask is 1 for attend and 0 for don't
         attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
-        print(f'BUSA unpadded final attention shape: {attention_probs.shape}')
-        print(f'ATTENTION: {attention.shape}')
+        # print(f'BUSA unpadded final attention shape: {attention_probs.shape}')
+        # print(f'ATTENTION: {attention.shape}')
 
-        print(f'PROBLEM HERE: UNDERSTAND IT!!')
+        # print(f'PROBLEM HERE: UNDERSTAND IT!!')
         rearranged_attention = rearrange(attention, 'nnz h d -> nnz (h d)')
         try:
-            print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
+            # print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
         except:
-            print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
+            # print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
         return rearrange(attention, 'nnz h d -> nnz (h d)'), attention_probs
 
 
@@ -257,10 +257,10 @@ class BertUnpadAttention(nn.Module):
         self_output, attention_probs = self.self(input_tensor, cu_seqlens, max_s, indices,
                                                  attn_mask, bias)
 
-        try:
-            print(f'IMPORTANT: {self_output.shape}')
-        except:
-            print(f'IMPORTANT2: {self_output[0].shape}')
+        # try:
+        # print(f'IMPORTANT: {self_output.shape}')
+        # except:
+        # print(f'IMPORTANT2: {self_output[0].shape}')
 
         if subset_idx is not None:
             return self.output(index_first_axis(self_output, subset_idx),
@@ -349,9 +349,9 @@ class BertLayer(nn.Module):
         """
         attention_output, attention_probs = self.attention(hidden_states, cu_seqlens, seqlen,
                                                            subset_idx, indices, attn_mask, bias)
-        print(f'BertLayer attention_output shape: {attention_output.shape}')
+        # print(f'BertLayer attention_output shape: {attention_output.shape}')
         layer_output = self.mlp(attention_output)
-        print(f'BertLayer layer_output shape: {layer_output.shape}')
+        # print(f'BertLayer layer_output shape: {layer_output.shape}')
         return layer_output, attention_output, attention_probs  # JAANDOUI: this only returns layer_output in the original work.
 
 
@@ -372,7 +372,7 @@ class BertEncoder(nn.Module):
             [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 
         self.num_attention_heads = config.num_attention_heads
-        print(f'nbr of attention heads: {self.num_attention_heads}')
+        # print(f'nbr of attention heads: {self.num_attention_heads}')
         # The alibi mask will be dynamically expanded if it is too small for
         # the input the model receives. But it generally helps to initialize it
         # to a reasonably large size to help pre-allocate CUDA memory.
@@ -481,7 +481,7 @@ class BertEncoder(nn.Module):
                     bias=alibi_attn_mask)
                 # JAANDOUI
                 # print(f'Inner Attention: {attention_weights}')
-                print(f'Inner Attention shape: {attention_weights.shape}')
+                # print(f'Inner Attention shape: {attention_weights.shape}')
                 all_attention_weights.append(attention_weights)  # Store attention weights
                 all_attention_probs.append(attention_probs)  # Store attention probs
 
@@ -523,7 +523,7 @@ class BertEncoder(nn.Module):
                 all_attention_probs.append(attention_probs)  # Store attention probs
 
         # print(f'here is the matrix of attentions inside encoder: \n {all_attention_weights}')
-        print(f'and this is the [0]shape inside encoder: \n {all_attention_weights[0].shape}')
+        # print(f'and this is the [0]shape inside encoder: \n {all_attention_weights[0].shape}')
         # print(f'NUMBER6: {all_attention_weights}')
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
@@ -663,7 +663,7 @@ class BertModel(BertPreTrainedModel):
             subset_mask=subset_mask)
         # print(f'NUMBER7: {all_attention_weights}')
         # print(f'here is the matrix of attentions in BERT: \n {all_attention_weights}')
-        print(f'and this is the [0]shape in BERT: \n {all_attention_weights[0].shape}')
+        # print(f'and this is the [0]shape in BERT: \n {all_attention_weights[0].shape}')
 
         if masked_tokens_mask is None:
             sequence_output = encoder_outputs[-1]
@@ -930,28 +930,28 @@ class BertForSequenceClassification(BertPreTrainedModel):
 
         pooled_output = outputs[1]
 
-        try:
-            print(f'outputs[2] before reassignment SHAPE: {outputs[3][0].shape} ')
-        except:
-            print(print(f'outputs[2] before reassignment LENGTH: {len(outputs[3][0])} '))
+        # try:
+        # print(f'outputs[2] before reassignment SHAPE: {outputs[3][0].shape} ')
+        # except:
+        # print(print(f'outputs[2] before reassignment LENGTH: {len(outputs[3][0])} '))
 
         # JAANDOUI:
         all_attention_probs = outputs[3]
 
-        try:
-            print(f'outputs[2] AFTER reassignment probsss SHAPE: {outputs[3][0].shape} ')
-        except:
-            print(print(f'outputs[2] AFTER reassignment probsss LENGTH: {len(outputs[3][0])} '))
+        # try:
+        # print(f'outputs[2] AFTER reassignment probsss SHAPE: {outputs[3][0].shape} ')
+        # except:
+        # print(print(f'outputs[2] AFTER reassignment probsss LENGTH: {len(outputs[3][0])} '))
 
 
 
-        try:
-            print(f'all_attention_weights probsss last: {all_attention_probs.shape}')
-        except:
-            try:
-                print(f'last first except probsss: {all_attention_probs[0].shape}')
-            except:
-                print(f'last second except probsss: {len(all_attention_probs[0])}')
+        # try:
+        # print(f'all_attention_weights probsss last: {all_attention_probs.shape}')
+        # except:
+        # try:
+        # print(f'last first except probsss: {all_attention_probs[0].shape}')
+        # except:
+        # print(f'last second except probsss: {len(all_attention_probs[0])}')
 
 
         pooled_output = self.dropout(pooled_output)
@@ -990,10 +990,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
             return ((loss,) + output) if loss is not None else output
 
         # print(outputs.attentions)
-        try:
-            print(f'not stacked final attention probsss SHAPE: {outputs[3][0].shape}')
-        except:
-            print(f'not stacked final attention probsss LEN: {len(outputs[3])}')
+        # try:
+        # print(f'not stacked final attention probsss SHAPE: {outputs[3][0].shape}')
+        # except:
+        # print(f'not stacked final attention probsss LEN: {len(outputs[3])}')
 
         # try:
         #     print(f'STACKED final attention SHAPE: {(outputs.attentions).shape}')
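
The diff silences the fork's debug prints while keeping the extra values it threads through the stack: BertLayer returns layer_output, attention_output and attention_probs, BertEncoder collects them into all_attention_weights / all_attention_probs, and BertForSequenceClassification reads the per-layer list from outputs[3]. A minimal sketch of how a caller might inspect those per-layer tensors follows; the checkpoint id is a placeholder, and loading via trust_remote_code plus the outputs[3] index are assumptions mirroring this diff, not a documented interface. Note also that attention_probs is only built on the plain PyTorch attention path; the flash-attention branch never computes it.

# Minimal sketch, assuming a checkpoint that ships this bert_layers.py via
# trust_remote_code and exposes the per-layer attention list at index 3 of the
# model's output tuple (mirroring `all_attention_probs = outputs[3]` above).
import torch
from transformers import AutoModel, AutoTokenizer

ckpt = "jaandoui/your-checkpoint"  # placeholder repo id, not a real one
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModel.from_pretrained(ckpt, trust_remote_code=True)

inputs = tokenizer("example input", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One entry per transformer layer; the shapes are only defined when the plain
# PyTorch attention path runs, since the flash-attention branch in the diff
# never materializes attention_probs.
for layer_idx, attn in enumerate(outputs[3]):
    print(layer_idx, tuple(attn.shape))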