Nanobit committed
Commit 2a801b0
1 Parent(s): e44c9e0

Fix grad checkpoint and outputs param

src/axolotl/monkeypatch/llama_landmark_attn.py CHANGED
@@ -27,7 +27,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-import transformers
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
@@ -52,10 +51,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
 MEM_TOKEN = "<landmark>"  # nosec
 
 
-def hijack_llama_landmark_attn():
-    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
-
-
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size,
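For reference, the deleted hijack_llama_landmark_attn helper (and the transformers import that only it used) was the standard module-swap monkeypatch: overwrite the class attribute on the upstream module so later lookups resolve to the patched implementation. A minimal sketch of that pattern, reusing the names from the removed code:

```python
import transformers

# The patched class defined in this module (the diff's LlamaForCausalLM).
from axolotl.monkeypatch.llama_landmark_attn import LlamaForCausalLM


def hijack_llama_landmark_attn():
    # Swap the upstream class for the landmark-attention variant so any
    # code resolving it through the transformers module tree picks up
    # the patched implementation.
    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
```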
@@ -1125,7 +1120,7 @@ class LlamaModel(LlamaPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, output_attentions, None)
+                        return module(*inputs)
 
                     return custom_forward
 
@@ -1135,6 +1130,8 @@ class LlamaModel(LlamaPreTrainedModel):
                     attention_mask,
                     position_ids,
                     None,
+                    output_attentions,
+                    None,
                     is_mem,
                     last_section_mask,
                 )
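The gradient-checkpoint half of the fix: the old closure appended output_attentions and a None after the checkpointed inputs, so they arrived after is_mem and last_section_mask rather than in the layer's declared positions. The fix forwards the inputs untouched and threads every argument through torch.utils.checkpoint.checkpoint positionally. A minimal runnable sketch of the fixed call shape; decoder_layer here is a hypothetical stand-in whose positional layout is inferred from the call site in this diff:

```python
import torch
import torch.utils.checkpoint


# Hypothetical layer with the assumed positional layout
# (hidden_states, attention_mask, position_ids, past_key_value,
#  output_attentions, use_cache, is_mem, last_section_mask).
def decoder_layer(
    hidden_states,
    attention_mask,
    position_ids,
    past_key_value,
    output_attentions,
    use_cache,
    is_mem,
    last_section_mask,
):
    # A trivial body; a real layer would use all of these arguments.
    return (hidden_states * 2,)


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Forward checkpoint's arguments untouched; appending extras here
        # (the pre-fix pattern) pushes them past is_mem/last_section_mask,
        # out of their declared slots.
        return module(*inputs)

    return custom_forward


hidden_states = torch.randn(1, 4, 8, requires_grad=True)
layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(decoder_layer),
    hidden_states,
    torch.zeros(1, 1, 4, 4),               # attention_mask
    torch.arange(4).unsqueeze(0),          # position_ids
    None,                                  # past_key_value
    False,                                 # output_attentions
    None,                                  # use_cache slot
    torch.zeros(1, 4, dtype=torch.bool),   # is_mem
    torch.ones(1, 4, dtype=torch.bool),    # last_section_mask
)
layer_outputs[0].sum().backward()  # gradients flow through the recomputed layer
```

Non-tensor arguments (the Nones and the bool flag) pass through checkpoint unchanged; only tensors participate in the save/recompute cycle, so threading flags positionally like this is safe.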
@@ -1300,7 +1297,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
                 return_dict=return_dict,
                 offload_cache_to_cpu=offload_cache_to_cpu,
             )
-            past_key_values = outputs[1]
+            past_key_values = outputs.past_key_values
             if last_logits is not None:
                 last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
             last_logits = outputs[0]
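The outputs-param half of the fix matters because integer indexing on a transformers ModelOutput enumerates only the fields that are not None, so outputs[1] points at different fields depending on flags like use_cache or output_hidden_states; outputs.past_key_values is stable. A small illustration using BaseModelOutputWithPast as a stand-in for the model's actual return type:

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

hidden = torch.zeros(1, 4, 8)

with_cache = BaseModelOutputWithPast(
    last_hidden_state=hidden,
    past_key_values=((torch.zeros(1, 2, 4, 4), torch.zeros(1, 2, 4, 4)),),
)
without_cache = BaseModelOutputWithPast(
    last_hidden_state=hidden,
    hidden_states=(hidden,),
)

# With a cache present, index 1 happens to be past_key_values...
assert with_cache[1] is with_cache.past_key_values
# ...but with caching off, index 1 silently lands on hidden_states:
assert without_cache[1] is without_cache.hidden_states
# Named access stays correct and returns None when caching is disabled:
assert without_cache.past_key_values is None
```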