Commit 12a2dbb by winglian
Parent: 3a2edc8

Support Sample packing for phi arch (#586)

* phi sequence packing

* sample packing fixes

* fix linting

* fix inference and phi e2e tests

* update phi example now that sample packing works

* wandb import keeps getting moved around
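
Background on the mechanism: sample packing concatenates several short examples into one `sequence_len`-long row, and the per-token `position_ids` restart at 0 at each packed example. The phi forward pass added below converts those position ids into cumulative sequence lengths (`cu_seqlens`) for flash-attn's varlen kernel via `axolotl.monkeypatch.utils.get_cu_seqlens_from_pos_ids`. A minimal standalone sketch of that conversion (illustrative only, not the axolotl helper itself):

    import torch


    def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
        """Recover packed-sequence boundaries from position_ids that restart at 0.

        Returns (cu_seqlens, max_seqlen) in the form flash-attn's varlen kernels expect.
        """
        pos = position_ids.view(-1)
        # A new packed example starts wherever the position counter resets to 0.
        starts = torch.nonzero(pos == 0, as_tuple=False).view(-1)
        ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
        seqlens = (ends - starts).to(torch.int32)
        cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
        cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
        return cu_seqlens, int(seqlens.max())


    # Three examples of lengths 3, 2 and 4 packed into a single row:
    pos_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
    print(cu_seqlens_from_position_ids(pos_ids))
    # (tensor([0, 3, 5, 9], dtype=torch.int32), 4)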

.mypy.ini CHANGED
@@ -8,6 +8,9 @@ ignore_missing_imports = True
 [mypy-axolotl.monkeypatch.*]
 ignore_errors = True
 
+[mypy-axolotl.models.phi.*]
+ignore_errors = True
+
 [mypy-flash_attn.*]
 ignore_missing_imports = True
 
@@ -20,6 +23,9 @@ ignore_missing_imports = True
 [mypy-peft]
 ignore_missing_imports = True
 
+[mypy-wandb]
+ignore_missing_imports = True
+
 [mypy-bitsandbytes]
 ignore_missing_imports = True
 
examples/phi/phi-ft.yml CHANGED
@@ -1,6 +1,6 @@
 base_model: microsoft/phi-1_5
 base_model_config: microsoft/phi-1_5
-model_type: AutoModelForCausalLM
+model_type: MixFormerSequentialForCausalLM
 tokenizer_type: AutoTokenizer
 is_llama_derived_model: false
 trust_remote_code: true
@@ -18,7 +18,7 @@ val_set_size: 0.05
 output_dir: ./phi-sft-out
 
 sequence_len: 2048
-sample_packing: false # does not work with phi
+sample_packing: true
 pad_to_sequence_len:
 
 adapter:
@@ -35,10 +35,10 @@ wandb_watch:
 wandb_run_id:
 wandb_log_model:
 
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 4
-optimizer: adamw_bnb_8bit
+optimizer: adamw_torch
 adam_beta2: 0.95
 adam_epsilon: 0.00001
 max_grad_norm: 1.0
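
Usage note (not part of the commit): the e2e tests further down drive training through the same Python entry points the axolotl CLI wraps, so the updated example config can be exercised programmatically along these lines, assuming axolotl and its dependencies are installed:

    # Hedged sketch: run the phi example config through the entry points
    # that the e2e tests below also use.
    import yaml

    from axolotl.cli import load_datasets
    from axolotl.common.cli import TrainerCliArgs
    from axolotl.train import train
    from axolotl.utils.config import normalize_config
    from axolotl.utils.dict import DictDefault

    with open("examples/phi/phi-ft.yml", encoding="utf-8") as fin:
        cfg = DictDefault(yaml.safe_load(fin))

    normalize_config(cfg)
    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)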
src/axolotl/models/__init__.py ADDED
(empty file)
src/axolotl/models/phi/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""
+MixFormers model architecture used for phi models
+"""
+
+from .configuration_mixformer_sequential import MixFormerSequentialConfig  # noqa
+from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM  # noqa
src/axolotl/models/phi/configuration_mixformer_sequential.py ADDED
@@ -0,0 +1,63 @@
+# pylint: skip-file
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import math
+from typing import Any, Dict, List, Optional, Union
+
+from transformers import PretrainedConfig
+
+
+class MixFormerSequentialConfig(PretrainedConfig):
+    """MixFormer (sequential for DeepSpeed) configuration."""
+
+    model_type = "mixformer-sequential"
+
+    attribute_map = {
+        "max_position_embeddings": "n_positions",
+        "hidden_size": "n_embd",
+        "num_attention_heads": "n_head",
+        "num_hidden_layers": "n_layer",
+        "input_emb_layer": "embd_layer",  # `input_emb_layer` key is for backward compatibility
+        "blocks": "architecture",  # `blocks` key is for backward compatibility
+    }
+
+    def __init__(
+        self,
+        vocab_size: Optional[int] = 50304,
+        n_positions: Optional[int] = 2048,
+        n_embd: Optional[int] = 1024,
+        n_layer: Optional[int] = 20,
+        n_inner: Optional[int] = None,
+        n_head: Optional[int] = 16,
+        rotary_dim: Optional[int] = 32,
+        activation_function: Optional[str] = "gelu_new",
+        embd_layer: Optional[str] = "default",
+        architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
+        embd_pdrop: Optional[float] = 0.0,
+        resid_pdrop: Optional[float] = 0.0,
+        layer_norm_epsilon: Optional[float] = 1e-5,
+        initializer_range: Optional[float] = 0.02,
+        tie_word_embeddings: Optional[bool] = False,
+        pad_vocab_size_multiple: Optional[int] = 64,
+        **kwargs
+    ) -> None:
+        self.vocab_size = int(
+            math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
+        )
+        self.n_positions = n_positions
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_inner = n_inner
+        self.n_head = n_head
+        self.rotary_dim = min(rotary_dim, n_embd // n_head)
+        self.activation_function = activation_function
+        self.embd_layer = embd_layer
+        self.architecture = architecture
+        self.embd_pdrop = embd_pdrop
+        self.resid_pdrop = resid_pdrop
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_range = initializer_range
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
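
A quick check of the behaviour above (values are illustrative): `vocab_size` is rounded up to the next multiple of `pad_vocab_size_multiple`, and `attribute_map` exposes the usual Hugging Face attribute names as aliases.

    from axolotl.models.phi.configuration_mixformer_sequential import (
        MixFormerSequentialConfig,
    )

    # 50295 tokens padded up to the next multiple of 64 -> 50304
    cfg = MixFormerSequentialConfig(vocab_size=50295)
    assert cfg.vocab_size == 50304
    # HF-style names resolve through attribute_map
    assert cfg.hidden_size == cfg.n_embd
    assert cfg.num_attention_heads == cfg.n_head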
src/axolotl/models/phi/modeling_mixformer_sequential.py ADDED
@@ -0,0 +1,934 @@
1
+ # pylint: skip-file
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # Licensed under the MIT license.
5
+
6
+ # BSD 3-Clause License
7
+ #
8
+ # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
9
+ # All rights reserved.
10
+ #
11
+ # Redistribution and use in source and binary forms, with or without
12
+ # modification, are permitted provided that the following conditions are met:
13
+ #
14
+ # * Redistributions of source code must retain the above copyright notice, this
15
+ # list of conditions and the following disclaimer.
16
+ #
17
+ # * Redistributions in binary form must reproduce the above copyright notice,
18
+ # this list of conditions and the following disclaimer in the documentation
19
+ # and/or other materials provided with the distribution.
20
+ #
21
+ # * Neither the name of the copyright holder nor the names of its
22
+ # contributors may be used to endorse or promote products derived from
23
+ # this software without specific prior written permission.
24
+ #
25
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
+
36
+ from __future__ import annotations
37
+
38
+ import copy
39
+ import inspect
40
+ from dataclasses import dataclass, field
41
+ from typing import Any, Dict, Optional, Tuple
42
+
43
+ import torch
44
+ import torch.nn as nn
45
+ from einops import rearrange
46
+ from flash_attn.flash_attn_interface import (
47
+ flash_attn_kvpacked_func,
48
+ flash_attn_qkvpacked_func,
49
+ flash_attn_varlen_qkvpacked_func,
50
+ )
51
+ from transformers import PretrainedConfig, PreTrainedModel
52
+ from transformers.activations import ACT2FN
53
+ from transformers.modeling_outputs import CausalLMOutputWithPast
54
+
55
+ from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
56
+ from .configuration_mixformer_sequential import MixFormerSequentialConfig
57
+
58
+
59
+ @dataclass
60
+ class InferenceParams:
61
+ """Inference parameters that are passed to the main model in order
62
+ to efficienly calculate and store the context during inference.
63
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
64
+
65
+ max_sequence_len: int
66
+ max_batch_size: int
67
+ sequence_len_offset: int = 0
68
+ batch_size_offset: int = 0
69
+ key_value_memory_dict: dict = field(default_factory=dict)
70
+ fused_ft_kernel: bool = False
71
+ lengths_per_sample: Optional[torch.Tensor] = None
72
+
73
+
74
+ class Embedding(nn.Module):
75
+ """Token embedding with dropout."""
76
+
77
+ def __init__(self, config: PretrainedConfig) -> None:
78
+ super().__init__()
79
+
80
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
81
+ self.drop = nn.Dropout(config.embd_pdrop)
82
+
83
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
84
+ input_shape = input_ids.size()
85
+ input_ids = input_ids.view(-1, input_shape[-1])
86
+
87
+ hidden_states = self.wte(input_ids)
88
+ hidden_states = self.drop(hidden_states)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class RotaryEmbedding(nn.Module):
94
+ """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
95
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
96
+
97
+ def __init__(
98
+ self,
99
+ dim: int,
100
+ base: Optional[int] = 10000,
101
+ scale_base: Optional[float] = None,
102
+ device: Optional[str] = None,
103
+ **kwargs,
104
+ ) -> None:
105
+ super().__init__()
106
+
107
+ if scale_base is not None:
108
+ raise NotImplementedError
109
+
110
+ # Generate and save the inverse frequency buffer (non-trainable)
111
+ self.dim = dim
112
+ self.base = base
113
+ self.scale_base = scale_base
114
+ self.device = device
115
+
116
+ inv_freq = 1.0 / (
117
+ base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
118
+ )
119
+ self.register_buffer("inv_freq", inv_freq)
120
+
121
+ scale = (
122
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
123
+ / (1.4 * dim)
124
+ if scale_base is not None
125
+ else None
126
+ )
127
+ self.register_buffer("scale", scale)
128
+
129
+ self._seq_len_cached = 0
130
+ self._cos_cached = None
131
+ self._sin_cached = None
132
+ self._cos_k_cached = None
133
+ self._sin_k_cached = None
134
+
135
+ def _update_cos_sin_cache(
136
+ self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0
137
+ ) -> None:
138
+ # Reset the tables if the sequence length has changed,
139
+ # or if we're on a new device (possibly due to tracing for instance)
140
+ seqlen = x.shape[1] + seqlen_offset
141
+
142
+ # Re-generate the inverse frequency buffer if it's not fp32
143
+ # (for instance if model.half() was called)
144
+ if self.inv_freq.dtype != "torch.float32":
145
+ self.inv_freq = 1.0 / (
146
+ self.base
147
+ ** (
148
+ torch.arange(
149
+ 0, self.dim, 2, device=self.device, dtype=torch.float32
150
+ )
151
+ / self.dim
152
+ )
153
+ )
154
+
155
+ if (
156
+ seqlen > self._seq_len_cached
157
+ or self._cos_cached.device != x.device
158
+ or self._cos_cached.dtype != x.dtype
159
+ ):
160
+ self._seq_len_cached = seqlen
161
+ t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
162
+
163
+ # Don't do einsum, it converts fp32 to fp16
164
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
165
+ freqs = torch.outer(
166
+ t, self.inv_freq.to(device=t.device, dtype=torch.float32)
167
+ )
168
+ if self.scale is None:
169
+ self._cos_cached = torch.cos(freqs).to(x.dtype)
170
+ self._sin_cached = torch.sin(freqs).to(x.dtype)
171
+ else:
172
+ power = (
173
+ torch.arange(
174
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
175
+ )
176
+ - seqlen // 2
177
+ ) / self.scale_base
178
+ scale = self.scale.to(device=power.device) ** rearrange(
179
+ power, "s -> s 1"
180
+ )
181
+
182
+ # We want the multiplication by scale to happen in fp32
183
+ self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
184
+ self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
185
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
186
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
187
+
188
+ def apply_rotary_emb_qkv(
189
+ self,
190
+ qkv: torch.FloatTensor,
191
+ sin: torch.FloatTensor,
192
+ cos: torch.FloatTensor,
193
+ sin_k: Optional[torch.FloatTensor] = None,
194
+ cos_k: Optional[torch.FloatTensor] = None,
195
+ ) -> torch.FloatTensor:
196
+ _, seqlen, three, _, headdim = qkv.shape
197
+ assert three == 3
198
+
199
+ rotary_seqlen, rotary_dim = cos.shape
200
+ rotary_dim *= 2
201
+ assert rotary_dim <= headdim
202
+ assert seqlen <= rotary_seqlen
203
+
204
+ cos_k = cos if cos_k is None else cos_k
205
+ sin_k = sin if sin_k is None else sin_k
206
+ assert (
207
+ sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
208
+ )
209
+
210
+ q_rot = qkv[:, :, 0, :, :rotary_dim]
211
+ q_pass = qkv[:, :, 0, :, rotary_dim:]
212
+
213
+ k_rot = qkv[:, :, 1, :, :rotary_dim]
214
+ k_pass = qkv[:, :, 1, :, rotary_dim:]
215
+
216
+ # Splits the queries and keys in half
217
+ q1, q2 = q_rot.chunk(2, dim=-1)
218
+ k1, k2 = k_rot.chunk(2, dim=-1)
219
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
220
+ sin[:seqlen], "s d -> s 1 d"
221
+ )
222
+
223
+ # Casts to fp32 are necessary to prevent fp16 overflow issues
224
+ q1, q2, k1, k2, c, s = [
225
+ t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]
226
+ ]
227
+
228
+ # Computes the new keys and queries, recasting to original dtype
229
+ q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
230
+
231
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
232
+
233
+ return torch.cat(
234
+ [
235
+ torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
236
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
237
+ qkv[:, :, 2:3, :, :],
238
+ ],
239
+ axis=2,
240
+ )
241
+
242
+ def forward(
243
+ self, qkv: torch.Tensor, seqlen_offset: int = 0
244
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
245
+ """Perform the forward pass.
246
+
247
+ Args:
248
+ qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
249
+ seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
250
+
251
+ Returns:
252
+ New `qkv` and the cached sinusoids.
253
+
254
+ """
255
+
256
+ self._update_cos_sin_cache(qkv, seqlen_offset)
257
+
258
+ return self.apply_rotary_emb_qkv(
259
+ qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
260
+ )
261
+
262
+
263
+ def _update_kv_cache(kv, inference_params, layer_idx):
264
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
265
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
266
+ # Pre-allocate memory for key-values for inference.
267
+ num_heads, head_dim = kv.shape[-2:]
268
+ if layer_idx not in inference_params.key_value_memory_dict:
269
+ kv_cache = torch.empty(
270
+ inference_params.max_batch_size,
271
+ inference_params.max_sequence_len,
272
+ 2,
273
+ num_heads,
274
+ head_dim,
275
+ dtype=kv.dtype,
276
+ device=kv.device,
277
+ )
278
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
279
+ else:
280
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
281
+
282
+ # Adjust key and value for inference
283
+ batch_start = inference_params.batch_size_offset
284
+ batch_end = batch_start + kv.shape[0]
285
+ sequence_start = inference_params.sequence_len_offset
286
+ sequence_end = sequence_start + kv.shape[1]
287
+ assert batch_end <= (
288
+ kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa
289
+ )
290
+ assert sequence_end <= (
291
+ kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa
292
+ )
293
+
294
+ assert kv_cache is not None
295
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
296
+ kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
297
+ return kv
298
+
299
+
300
+ class MLP(nn.Module):
301
+ """Multi-Layer Perceptron.
302
+
303
+ Reference:
304
+ Attention Is All You Need.
305
+ https://arxiv.org/pdf/1706.03762.pdf.
306
+
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ config: PretrainedConfig,
312
+ n_inner: Optional[int] = None,
313
+ act_fn: Optional[str] = None,
314
+ ) -> None:
315
+ super().__init__()
316
+
317
+ act_fn = config.activation_function if act_fn is None else act_fn
318
+ assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
319
+
320
+ n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
321
+ n_inner = n_inner if n_inner is not None else 4 * config.n_embd
322
+
323
+ self.fc1 = nn.Linear(config.n_embd, n_inner)
324
+ self.fc2 = nn.Linear(n_inner, config.n_embd)
325
+ self.act = ACT2FN[act_fn]
326
+
327
+ def _load_from_state_dict(
328
+ self,
329
+ state_dict,
330
+ prefix,
331
+ local_metadata,
332
+ strict,
333
+ missing_keys,
334
+ unexpected_keys,
335
+ error_msgs,
336
+ ):
337
+ old_keys = [
338
+ prefix + "fc_in.weight",
339
+ prefix + "fc_out.weight",
340
+ prefix + "fc_in.bias",
341
+ prefix + "fc_out.bias",
342
+ ]
343
+ new_keys = [
344
+ prefix + "fc1.weight",
345
+ prefix + "fc2.weight",
346
+ prefix + "fc1.bias",
347
+ prefix + "fc2.bias",
348
+ ]
349
+
350
+ if all(k in state_dict for k in old_keys) and not all(
351
+ k in state_dict for k in new_keys
352
+ ):
353
+ # Older version of `MLP` saved with different key names.
354
+ for old_key, new_key in zip(old_keys, new_keys):
355
+ state_dict[new_key] = state_dict.pop(old_key)
356
+
357
+ return super()._load_from_state_dict(
358
+ state_dict,
359
+ prefix,
360
+ local_metadata,
361
+ strict,
362
+ missing_keys,
363
+ unexpected_keys,
364
+ error_msgs,
365
+ )
366
+
367
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
368
+ hidden_states = self.fc1(hidden_states)
369
+ hidden_states = self.act(hidden_states)
370
+ hidden_states = self.fc2(hidden_states)
371
+
372
+ return hidden_states
373
+
374
+
375
+ class FusedMLP(nn.Module):
376
+ """Fused Multi-Layer Perceptron from `flash-attn`.
377
+
378
+ Reference:
379
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
380
+
381
+ """
382
+
383
+ def __init__(
384
+ self,
385
+ config: PretrainedConfig,
386
+ n_inner: Optional[int] = None,
387
+ act_fn: Optional[str] = None,
388
+ raise_on_missing: bool = False,
389
+ ) -> None:
390
+ super().__init__()
391
+
392
+ act_fn = config.activation_function if act_fn is None else act_fn
393
+ assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
394
+
395
+ n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
396
+ n_inner = n_inner if n_inner is not None else 4 * config.n_embd
397
+
398
+ gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa
399
+ activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa
400
+
401
+ self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
402
+
403
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
404
+ return self.mlp(hidden_states)
405
+
406
+
407
+ class SelfAttention(nn.Module):
408
+ """Implement the scaled dot product attention with softmax.
409
+ Adapted from https://github.com/Dao-AILab/flash-attention.
410
+ Arguments
411
+ ---------
412
+ softmax_scale: The temperature to use for the softmax attention.
413
+ (default: 1/sqrt(d_keys) where d_keys is computed at
414
+ runtime)
415
+ attention_dropout: The dropout rate to apply to the attention
416
+ (default: 0.0)
417
+ """
418
+
419
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
420
+ super().__init__()
421
+ self.causal = causal
422
+ self.softmax_scale = softmax_scale
423
+ self.drop = nn.Dropout(attention_dropout)
424
+
425
+ def forward(
426
+ self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
427
+ ):
428
+ """Implements the multihead softmax attention.
429
+ Arguments
430
+ ---------
431
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
432
+ causal: if passed, will override self.causal
433
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
434
+ False means to mask out. (B, S)
435
+ """
436
+ causal = self.causal if causal is None else causal
437
+ if cu_seqlens is not None:
438
+ return flash_attn_varlen_qkvpacked_func(
439
+ qkv.squeeze(0),
440
+ cu_seqlens,
441
+ max_seqlen,
442
+ dropout_p=self.drop.p,
443
+ softmax_scale=self.softmax_scale,
444
+ causal=causal,
445
+ )
446
+ else:
447
+ return flash_attn_qkvpacked_func(
448
+ qkv,
449
+ dropout_p=self.drop.p,
450
+ softmax_scale=self.softmax_scale,
451
+ causal=causal,
452
+ )
453
+
454
+
455
+ class CrossAttention(nn.Module):
456
+ """Implement the scaled dot product attention with softmax.
457
+ Adapted from https://github.com/Dao-AILab/flash-attention.
458
+ Arguments
459
+ ---------
460
+ softmax_scale: The temperature to use for the softmax attention.
461
+ (default: 1/sqrt(d_keys) where d_keys is computed at
462
+ runtime)
463
+ attention_dropout: The dropout rate to apply to the attention
464
+ (default: 0.0)
465
+ """
466
+
467
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
468
+ super().__init__()
469
+ self.causal = causal
470
+ self.softmax_scale = softmax_scale
471
+ self.drop = nn.Dropout(attention_dropout)
472
+
473
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
474
+ """Implements the multihead softmax attention.
475
+ Arguments
476
+ ---------
477
+ q: The tensor containing the query. (B, Sq, H, D)
478
+ kv: The tensor containing the key and value. (B, Sk, 2, H, D)
479
+ causal: if passed, will override self.causal
480
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
481
+ False means to mask out. (B, Sk)
482
+ """
483
+ causal = self.causal if causal is None else causal
484
+ return flash_attn_kvpacked_func(
485
+ q,
486
+ kv,
487
+ dropout_p=self.drop.p,
488
+ softmax_scale=self.softmax_scale,
489
+ causal=causal,
490
+ )
491
+
492
+
493
+ def find_mha_dims(
494
+ config: PretrainedConfig,
495
+ n_head: Optional[int] = None,
496
+ head_dim: Optional[int] = None,
497
+ ) -> Tuple[int, int]:
498
+ """Validate and return the number of heads and head dimension for multi-head attention.
499
+
500
+ Args:
501
+ config: Model configuration.
502
+ n_head: Number of heads.
503
+ head_dim: Head dimension.
504
+
505
+ Returns:
506
+ Number of heads and head dimension.
507
+
508
+ """
509
+
510
+ assert all(
511
+ hasattr(config, attr) for attr in ["n_embd", "n_head"]
512
+ ), "`config` must have `n_embd` and `n_head` attributes."
513
+
514
+ if head_dim is None:
515
+ assert (
516
+ config.n_embd % config.n_head == 0
517
+ ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
518
+
519
+ if n_head is None and head_dim is None:
520
+ head_dim = config.n_embd // config.n_head
521
+ n_head = config.n_head
522
+ elif n_head is None or head_dim is None:
523
+ raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
524
+
525
+ return n_head, head_dim
526
+
527
+
528
+ class MHA(nn.Module):
529
+ """Multi-head attention layer.
530
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
531
+
532
+ def __init__(
533
+ self,
534
+ config: PretrainedConfig,
535
+ rotary_dim: Optional[int] = None,
536
+ n_head: Optional[int] = None,
537
+ head_dim: Optional[int] = None,
538
+ bias: Optional[bool] = True,
539
+ dropout: Optional[float] = 0.0,
540
+ softmax_scale: Optional[float] = None,
541
+ causal: Optional[bool] = True,
542
+ layer_idx: Optional[int] = None,
543
+ rotary_emb_scale_base: Optional[float] = None,
544
+ return_residual: Optional[bool] = False,
545
+ checkpointing: Optional[bool] = False,
546
+ device: Optional[str] = None,
547
+ dtype: Optional[torch.dtype] = None,
548
+ fused_dense: Optional[bool] = True,
549
+ flash_attn: Optional[bool] = True,
550
+ cutlass_attn: Optional[bool] = False,
551
+ flash_rotary: Optional[bool] = True,
552
+ raise_on_missing: Optional[bool] = False,
553
+ ) -> None:
554
+ super().__init__()
555
+
556
+ factory_kwargs = {"device": device, "dtype": dtype}
557
+ n_head, head_dim = find_mha_dims(config, n_head, head_dim)
558
+
559
+ self.hidden_size = config.n_embd
560
+ self.n_head = n_head
561
+ self.head_dim = head_dim
562
+ self.op_size = n_head * head_dim
563
+
564
+ self.causal = causal
565
+ self.layer_idx = layer_idx
566
+ self.rotary_emb_dim = (
567
+ rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
568
+ )
569
+ self.fused_dense = fused_dense
570
+ self.flash_attn = flash_attn
571
+ self.cutlass_attn = cutlass_attn
572
+ self.flash_rotary = flash_rotary
573
+ self.return_residual = return_residual
574
+ self.checkpointing = checkpointing
575
+
576
+ if self.rotary_emb_dim > 0:
577
+ rotary_kwargs = {"device": device}
578
+ if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
579
+ rotary_kwargs["scale_base"] = rotary_emb_scale_base
580
+
581
+ self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
582
+ else:
583
+ pass
584
+
585
+ self.Wqkv = nn.Linear(
586
+ self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs
587
+ )
588
+ self.out_proj = nn.Linear(
589
+ self.op_size, self.hidden_size, bias=bias, **factory_kwargs
590
+ )
591
+
592
+ self.inner_attn = SelfAttention(
593
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
594
+ )
595
+ self.inner_cross_attn = CrossAttention(
596
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
597
+ )
598
+
599
+ def _update_kv_cache(
600
+ self, kv: torch.FloatTensor, inference_params: InferenceParams
601
+ ) -> None:
602
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
603
+ Adapted from https://github.com/Dao-AILab/flash-attention."""
604
+
605
+ assert (
606
+ self.layer_idx is not None
607
+ ), "Generation requires layer_idx in the constructor"
608
+
609
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
610
+
611
+ def forward(
612
+ self,
613
+ x: torch.FloatTensor,
614
+ x_kv: Optional[torch.FloatTensor] = None,
615
+ key_padding_mask: Optional[torch.BoolTensor] = None,
616
+ cu_seqlens: Optional[torch.LongTensor] = None,
617
+ max_seqlen: Optional[int] = None,
618
+ mixer_subset: Optional[torch.LongTensor] = None,
619
+ past_cache: Optional[InferenceParams] = None,
620
+ **kwargs,
621
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
622
+ """Perform the forward pass.
623
+
624
+ Args:
625
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
626
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
627
+ is the is the sum of the sequence lengths in the batch.
628
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
629
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
630
+ (batch, seqlen). Only applicable when not using FlashAttention.
631
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
632
+ of the sequences in the batch, used to index into x. Only applicable when using
633
+ FlashAttention.
634
+ max_seqlen: int. Maximum sequence length in the batch.
635
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
636
+ before applying the query projection. Useful for e.g., ViT where we only care
637
+ about the CLS token in the last layer.
638
+ past_cache: For generation only.
639
+
640
+ Returns:
641
+ (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
642
+ else (total, hidden_dim) where total is the is the sum of the sequence lengths
643
+ in the batch.
644
+
645
+ """
646
+
647
+ if cu_seqlens is not None:
648
+ assert max_seqlen is not None
649
+ assert key_padding_mask is None
650
+ assert self.flash_attn
651
+ # assert self.rotary_emb_dim == 0
652
+
653
+ if key_padding_mask is not None:
654
+ assert cu_seqlens is None
655
+ assert max_seqlen is None
656
+ assert not self.flash_attn
657
+
658
+ if past_cache is not None:
659
+ assert key_padding_mask is None
660
+ assert cu_seqlens is None and max_seqlen is None
661
+
662
+ attn_kwargs = {"key_padding_mask": key_padding_mask}
663
+
664
+ assert x_kv is None and mixer_subset is None
665
+
666
+ qkv = self.Wqkv(x)
667
+ qkv = rearrange(
668
+ qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
669
+ )
670
+
671
+ if past_cache is None:
672
+ if self.rotary_emb_dim > 0:
673
+ qkv = self.rotary_emb(qkv)
674
+ context = self.inner_attn(
675
+ qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs
676
+ )
677
+
678
+ else:
679
+ if self.rotary_emb_dim > 0:
680
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
681
+ q = qkv[:, :, 0]
682
+ kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
683
+ # If we're processing the prompt, causal=None (use self.causal).
684
+ # If we're decoding, then causal=False.
685
+ causal = None if past_cache.sequence_len_offset == 0 else False
686
+ context = self.inner_cross_attn(q, kv, causal=causal)
687
+
688
+ out = rearrange(context, "... h d -> ... (h d)")
689
+ out = self.out_proj(out)
690
+
691
+ return out if not self.return_residual else (out, x)
692
+
693
+
694
+ class ParallelBlock(nn.Module):
695
+ """Parallel block.
696
+
697
+ This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
698
+
699
+ """
700
+
701
+ def __init__(
702
+ self,
703
+ config: PretrainedConfig,
704
+ mixer: Optional[Dict[str, Any]] = None,
705
+ mlp: Optional[Dict[str, Any]] = None,
706
+ block_idx: Optional[int] = None,
707
+ ) -> None:
708
+ super().__init__()
709
+
710
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
711
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
712
+ self.block_idx = block_idx
713
+
714
+ self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
715
+ mlp_cls = mlp.pop("mlp_cls")
716
+ if mlp_cls == "fused_mlp":
717
+ self.mlp = FusedMLP(config=config, **mlp)
718
+ else:
719
+ self.mlp = MLP(config=config, **mlp)
720
+
721
+ def forward(
722
+ self,
723
+ hidden_states: torch.FloatTensor,
724
+ past_cache: Optional[torch.FloatTensor] = None,
725
+ cu_seqlens: Optional[torch.LongTensor] = None,
726
+ max_seqlen: Optional[int] = None,
727
+ ) -> torch.FloatTensor:
728
+ residual = hidden_states
729
+ hidden_states = self.ln(hidden_states)
730
+
731
+ attn_outputs = self.mixer(
732
+ hidden_states,
733
+ past_cache=past_cache,
734
+ cu_seqlens=cu_seqlens,
735
+ max_seqlen=max_seqlen,
736
+ )
737
+ if isinstance(attn_outputs, tuple):
738
+ attn_outputs = attn_outputs[0]
739
+
740
+ attn_outputs = self.resid_dropout(attn_outputs)
741
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
742
+
743
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
744
+
745
+ return hidden_states
746
+
747
+
748
+ class CausalLMHead(nn.Module):
749
+ """Causal Language Modeling head.
750
+
751
+ Reference:
752
+ Improving Language Understanding by Generative Pre-Training.
753
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
754
+
755
+ """
756
+
757
+ def __init__(self, config: PretrainedConfig) -> None:
758
+ super().__init__()
759
+
760
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
761
+ self.linear = nn.Linear(config.n_embd, config.vocab_size)
762
+
763
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
764
+ hidden_states = self.ln(hidden_states)
765
+ logits = self.linear(hidden_states).to(torch.float32)
766
+
767
+ return logits
768
+
769
+
770
+ class CausalLMLoss(nn.Module):
771
+ """Causal Language Modeling loss.
772
+
773
+ Reference:
774
+ Improving Language Understanding by Generative Pre-Training.
775
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
776
+
777
+ """
778
+
779
+ def __init__(self, shift_labels: Optional[bool] = True) -> None:
780
+ super().__init__()
781
+
782
+ self.shift_labels = shift_labels
783
+ self.loss_fct = nn.CrossEntropyLoss()
784
+
785
+ def forward(
786
+ self, logits: torch.FloatTensor, labels: torch.LongTensor
787
+ ) -> torch.FloatTensor:
788
+ if self.shift_labels:
789
+ logits = logits[..., :-1, :].contiguous()
790
+ labels = labels[..., 1:].contiguous()
791
+
792
+ loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
793
+
794
+ return loss
795
+
796
+
797
+ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
798
+ """MixFormer (sequential for DeepSpeed) pre-trained model."""
799
+
800
+ config_class = MixFormerSequentialConfig
801
+ base_model_prefix = "transformer"
802
+ supports_gradient_checkpointing = True
803
+
804
+ def __init__(self, *inputs, **kwargs) -> None:
805
+ super().__init__(*inputs, **kwargs)
806
+
807
+ def prepare_inputs_for_generation(
808
+ self, input_ids, past_key_values=None, **kwargs
809
+ ) -> Dict[str, Any]:
810
+ if "use_cache" in kwargs and not kwargs["use_cache"]:
811
+ return {"input_ids": input_ids}
812
+
813
+ if past_key_values is None or not (
814
+ isinstance(past_key_values, InferenceParams)
815
+ ):
816
+ past_key_values = InferenceParams(
817
+ max_batch_size=input_ids.shape[0],
818
+ max_sequence_len=self.config.n_positions,
819
+ sequence_len_offset=0,
820
+ batch_size_offset=0,
821
+ fused_ft_kernel=False,
822
+ key_value_memory_dict={},
823
+ )
824
+ else:
825
+ # assume past_key_values has cached all but last token in input_ids
826
+ past_key_values.sequence_len_offset = len(input_ids[0]) - 1
827
+ input_ids = input_ids[:, -1].unsqueeze(-1)
828
+
829
+ return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
830
+
831
+
832
+ class PackedSequential(nn.Sequential):
833
+ def forward(
834
+ self,
835
+ input,
836
+ cu_seqlens: Optional[torch.LongTensor] = None,
837
+ max_seqlen: Optional[int] = None,
838
+ ):
839
+ for module in self:
840
+ sig = inspect.signature(module.forward)
841
+ if "cu_seqlens" in sig.parameters:
842
+ input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
843
+ else:
844
+ input = module(input)
845
+ return input
846
+
847
+
848
+ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
849
+ """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
850
+
851
+ _keys_to_ignore_on_load_missing = [""]
852
+ _keys_to_ignore_on_load_unexpected = [
853
+ r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
854
+ ]
855
+ _no_split_modules = ["ParallelBlock"]
856
+
857
+ def __init__(self, config: MixFormerSequentialConfig) -> None:
858
+ super().__init__(config)
859
+
860
+ modules = [Embedding(config)]
861
+ block_config = config.architecture
862
+
863
+ if not isinstance(block_config, list):
864
+ block_config = [block_config for _ in range(config.n_layer)]
865
+
866
+ if config.n_layer != len(block_config):
867
+ config.n_layer = len(block_config)
868
+
869
+ for block_idx, block in enumerate(block_config):
870
+ # `block_cls` with `legacy` value is for backward compatibility
871
+ # `path` key is for backward compatibility
872
+ block = copy.deepcopy(block) or {"block_cls": "parallel"}
873
+ # block_cls = block.pop("path", None) or block.pop("block_cls", None)
874
+
875
+ block["block_idx"] = block_idx
876
+ modules.append(ParallelBlock(config, **block))
877
+
878
+ modules.append(CausalLMHead(config))
879
+
880
+ self.layers = PackedSequential(*modules)
881
+ self.loss = CausalLMLoss()
882
+
883
+ self.post_init()
884
+
885
+ def get_input_embeddings(self) -> nn.Embedding:
886
+ return self.layers[0].wte
887
+
888
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
889
+ self.layers[0].wte = new_embeddings
890
+
891
+ def get_output_embeddings(self) -> nn.Linear:
892
+ return self.layers[-1].linear
893
+
894
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
895
+ self.layers[-1].linear = new_embeddings
896
+
897
+ def forward(
898
+ self,
899
+ input_ids: torch.LongTensor,
900
+ labels: Optional[torch.LongTensor] = None,
901
+ past_key_values: Optional[torch.FloatTensor] = None,
902
+ position_ids: Optional[torch.LongTensor] = None,
903
+ **kwargs,
904
+ ) -> CausalLMOutputWithPast:
905
+ cu_seqlens: Optional[torch.LongTensor] = None
906
+ max_seqlen: Optional[int] = None
907
+ if position_ids is not None:
908
+ batch_size, seq_length = input_ids.shape
909
+ position_ids = position_ids.view(-1, seq_length).long()
910
+ cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
911
+ cu_seqlens = cu_seqlens.squeeze()
912
+
913
+ if not past_key_values:
914
+ lm_logits = self.layers(
915
+ input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
916
+ )
917
+ else:
918
+ hidden_layer = self.layers[0](input_ids)
919
+ for module in self.layers[1:-1]:
920
+ hidden_layer = module(
921
+ hidden_layer,
922
+ past_cache=past_key_values,
923
+ cu_seqlens=cu_seqlens,
924
+ max_seqlen=max_seqlen,
925
+ )
926
+ lm_logits = self.layers[-1](hidden_layer)
927
+
928
+ loss = None
929
+ if labels is not None:
930
+ loss = self.loss(lm_logits, labels)
931
+
932
+ return CausalLMOutputWithPast(
933
+ loss=loss, logits=lm_logits, past_key_values=past_key_values
934
+ )
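
To make the packed code path above concrete: when sample packing is on, `position_ids` restart at each packed example, `forward` derives `cu_seqlens`/`max_seqlen` from them, and `SelfAttention` routes through `flash_attn_varlen_qkvpacked_func`. A hedged sketch of what a packed micro-batch looks like at this boundary (requires a CUDA device with flash-attn installed; the token ids are placeholders, not a real tokenized sample):

    import torch

    from axolotl.models.phi import MixFormerSequentialForCausalLM

    model = MixFormerSequentialForCausalLM.from_pretrained(
        "microsoft/phi-1_5", torch_dtype=torch.bfloat16
    ).cuda()

    # Two examples of lengths 4 and 3 packed into one row; forward() derives
    # cu_seqlens = [0, 4, 7] and max_seqlen = 4 from these position ids.
    input_ids = torch.randint(0, model.config.vocab_size, (1, 7), device="cuda")
    position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2]], device="cuda")
    labels = input_ids.clone()

    out = model(input_ids, labels=labels, position_ids=position_ids)
    print(out.loss)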
src/axolotl/utils/models.py CHANGED
@@ -221,6 +221,17 @@ def load_model(
         # device=cfg.device,
         # )
         # model.train() # sets to train instead of eval mode
+    elif model_type == "MixFormerSequentialForCausalLM":
+        from axolotl.models.phi import MixFormerSequentialForCausalLM
+
+        model = MixFormerSequentialForCausalLM.from_pretrained(
+            base_model,
+            device_map=cfg.device_map,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+            torch_dtype=cfg.torch_dtype,
+            **model_kwargs,
+        )
     elif model_type and not cfg.trust_remote_code:
         if cfg.gptq:
             model = AutoModelForCausalLM.from_pretrained(
tests/e2e/.gitignore ADDED
@@ -0,0 +1 @@
+last_run_prepared
tests/e2e/test_lora_llama.py CHANGED
@@ -7,39 +7,23 @@ import os
 import tempfile
 import unittest
 
+from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
-from axolotl.train import TrainDatasetMeta, train
+from axolotl.train import train
 from axolotl.utils.config import normalize_config
-from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_tokenizer
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
-def load_datasets(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,  # pylint:disable=unused-argument
-) -> TrainDatasetMeta:
-    tokenizer = load_tokenizer(cfg)
-
-    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
-
-    return TrainDatasetMeta(
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
-        total_num_steps=total_num_steps,
-    )
-
-
 class TestLoraLlama(unittest.TestCase):
     """
     Test case for Llama models using LoRA
     """
 
     def test_lora(self):
+        # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -80,6 +64,7 @@ class TestLoraLlama(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
 
     def test_lora_packing(self):
+        # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
tests/e2e/test_phi.py ADDED
@@ -0,0 +1,109 @@
+"""
+E2E tests for phi finetuning
+"""
+
+import logging
+import os
+import tempfile
+import unittest
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestPhi(unittest.TestCase):
+    """
+    Test case for phi models
+    """
+
+    def test_ft(self):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "microsoft/phi-1_5",
+                "base_model_config": "microsoft/phi-1_5",
+                "trust_remote_code": True,
+                "model_type": "MixFormerSequentialForCausalLM",
+                "tokenizer_type": "AutoTokenizer",
+                "sequence_len": 2048,
+                "sample_packing": False,
+                "load_in_8bit": True,
+                "adapter": None,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<|endoftext|>",
+                    "bos_token": "<|endoftext|>",
+                    "eos_token": "<|endoftext|>",
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "dataset_shard_num": 10,
+                "dataset_shard_idx": 0,
+                "num_epochs": 1,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": tempfile.mkdtemp(),
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+
+    def test_ft_packed(self):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "microsoft/phi-1_5",
+                "base_model_config": "microsoft/phi-1_5",
+                "trust_remote_code": True,
+                "model_type": "MixFormerSequentialForCausalLM",
+                "tokenizer_type": "AutoTokenizer",
+                "sequence_len": 2048,
+                "sample_packing": True,
+                "load_in_8bit": True,
+                "adapter": None,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<|endoftext|>",
+                    "bos_token": "<|endoftext|>",
+                    "eos_token": "<|endoftext|>",
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "dataset_shard_num": 10,
+                "dataset_shard_idx": 0,
+                "num_epochs": 1,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": tempfile.mkdtemp(),
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
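
Both tests download microsoft/phi-1_5 and run a full training epoch, so they expect a CUDA GPU with flash-attn installed. One way to run just this file (a suggestion, not part of the commit):

    # Equivalent to running `pytest tests/e2e/test_phi.py -v` from the repo root.
    import sys

    import pytest

    sys.exit(pytest.main(["tests/e2e/test_phi.py", "-v"]))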