chenkq committed on
Commit
54b93e0
1 Parent(s): 2d9f231

Delete util.py

Files changed (1)
  1. util.py +0 -483
util.py DELETED
@@ -1,483 +0,0 @@
- from typing import Optional, Tuple, Union
-
- import torch
- from einops import rearrange, repeat
- import torch.nn.functional as F
-
- import triton
- import triton.language as tl
-
-
- # @triton.autotune(
- #     configs=[
- #         triton.Config({"BLOCK_M": 2}),
- #         triton.Config({"BLOCK_M": 4}),
- #         triton.Config({"BLOCK_M": 8}),
- #         triton.Config({"BLOCK_M": 16}),
- #     ],
- #     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
- # )
- @triton.jit
- def rotary_kernel(
-     OUT,  # Pointers to matrices
-     X,
-     COS,
-     SIN,
-     CU_SEQLENS,
-     SEQLEN_OFFSETS,  # this could be int or a pointer
-     # Matrix dimensions
-     seqlen,
-     nheads,
-     rotary_dim,
-     seqlen_ro,
-     CACHE_KEY_SEQLEN,
-     # strides
-     stride_out_batch,
-     stride_out_nheads,
-     stride_out_seqlen,
-     stride_out_headdim,
-     stride_x_batch,
-     stride_x_nheads,
-     stride_x_seqlen,
-     stride_x_headdim,
-     # Meta-parameters
-     BLOCK_K: tl.constexpr,
-     IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
-     IS_VARLEN: tl.constexpr,
-     INTERLEAVED: tl.constexpr,
-     CONJUGATE: tl.constexpr,
-     BLOCK_M: tl.constexpr,
- ):
-     pid_m = tl.program_id(axis=0)
-     pid_batch = tl.program_id(axis=1)
-     pid_head = tl.program_id(axis=2)
-     rotary_dim_half = rotary_dim // 2
-
-     if not IS_VARLEN:
-         X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
-         OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
-         COS = COS + pid_batch * seqlen_ro * rotary_dim_half
-         SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
-     else:
-         start_idx = tl.load(CU_SEQLENS + pid_batch)
-         seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
-         X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
-         OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
-
-     if pid_m * BLOCK_M >= seqlen:
-         return
-     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-     if not IS_SEQLEN_OFFSETS_TENSOR:
-         rm_cs = rm + SEQLEN_OFFSETS
-     else:
-         rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
-     rk = tl.arange(0, BLOCK_K)
-     rk_half = tl.arange(0, BLOCK_K // 2)
-
-     if not INTERLEAVED:
-         # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
-         X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
-         COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
-         SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
-         cos = tl.load(
-             COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
-         )
-         sin = tl.load(
-             SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
-         )
-         x0 = tl.load(
-             X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
-         )
-         x1 = tl.load(
-             X + rotary_dim_half * stride_x_headdim,
-             mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
-             other=0.0,
-         )
-         if CONJUGATE:
-             sin = -sin
-         o0 = x0 * cos - x1 * sin
-         o1 = x0 * sin + x1 * cos
-         # write back result
-         OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
-         tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
-         tl.store(
-             OUT + rotary_dim_half * stride_out_headdim,
-             o1,
-             mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
-         )
-     else:
-         # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
-         # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
-         # Loading x0 will be fast but x1 will be slow.
-         # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
-         # Then we do the calculation and use tl.where to pick out the right outputs for the even
-         # and for the odd indices.
-         rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
-         rk_repeat = tl.arange(0, BLOCK_K) // 2
-         X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
-         X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
-         COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
-         SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
-         cos = tl.load(
-             COS,
-             mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
-             other=1.0,
-         ).to(tl.float32)
-         sin = tl.load(
-             SIN,
-             mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
-             other=0.0,
-         ).to(tl.float32)
-         x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
-             tl.float32
-         )
-         x1 = tl.load(
-             X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
-         ).to(tl.float32)
-         if CONJUGATE:
-             sin = -sin
-         x0_cos = x0 * cos
-         x1_sin = x1 * sin
-         out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
-         OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
-         tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
-
-
- def apply_rotary(
-     x: torch.Tensor,
-     cos: torch.Tensor,
-     sin: torch.Tensor,
-     seqlen_offsets: Union[int, torch.Tensor] = 0,
-     cu_seqlens: Optional[torch.Tensor] = None,
-     max_seqlen: Optional[int] = None,
-     interleaved=False,
-     inplace=False,
-     conjugate=False,
- ) -> torch.Tensor:
-     """
-     Arguments:
-         x: (batch, nheads, seqlen, headdim)
-         cos: (batch, seqlen_ro, rotary_dim / 2)
-         sin: (batch, seqlen_ro, rotary_dim / 2)
-         seqlen_offsets: integer or integer tensor of size (batch,)
-         cu_seqlens: (batch + 1,) or None
-         max_seqlen: int
-     Returns:
-         y: (batch, nheads, seqlen, headdim)
-     """
-
-     batch, nheads, seqlen, headdim = x.shape
-
-     batch_ro, seqlen_ro, rotary_dim = cos.shape
-
-     assert batch == batch_ro
-     assert sin.shape == cos.shape
-     rotary_dim *= 2
-     assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
-     assert headdim <= 256, "Only support headdim <= 256"
-
-     assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
-
-     assert (
-         cos.dtype == sin.dtype
-     ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
-     assert (
-         x.dtype == cos.dtype
-     ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
-
-     cos, sin = cos.contiguous(), sin.contiguous()
-     if isinstance(seqlen_offsets, torch.Tensor):
-         assert seqlen_offsets.shape == (batch,)
-         assert seqlen_offsets.dtype in [torch.int32, torch.int64]
-         seqlen_offsets = seqlen_offsets.contiguous()
-     else:
-         assert seqlen_offsets + seqlen <= seqlen_ro
-
-     output = torch.empty_like(x) if not inplace else x
-     if rotary_dim < headdim and not inplace:
-         output[..., rotary_dim:].copy_(x[..., rotary_dim:])
-
-     BLOCK_K = (
-         32
-         if rotary_dim <= 32
-         else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
-     )
-     grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
-     BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
-
-     # Need this, otherwise Triton tries to launch from cuda:0 and we get
-     # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-     with torch.cuda.device(x.device.index):
-         rotary_kernel[grid](
-             output,  # data ptrs
-             x,
-             cos,
-             sin,
-             cu_seqlens,
-             seqlen_offsets,
-             seqlen,  # shapes
-             nheads,
-             rotary_dim,
-             seqlen_ro,
-             seqlen // 128,  # key for triton cache (limit number of compilations)
-             output.stride(0),  # batch stride
-             output.stride(-3),  # nheads stride
-             output.stride(-2),  # seqlen stride
-             output.stride(-1),  # headdim stride
-             x.stride(0),  # batch stride
-             x.stride(-3),  # nheads stride
-             x.stride(-2),  # seqlen stride
-             x.stride(-1),  # headdim stride
-             BLOCK_K,
-             isinstance(seqlen_offsets, torch.Tensor),
-             False,
-             interleaved,
-             conjugate,
-             BLOCK_M,
-         )
-     return output
-
-
- class ApplyRotaryEmb(torch.autograd.Function):
-     @staticmethod
-     def forward(
-         ctx,
-         x,
-         cos,
-         sin,
-         interleaved=False,
-         inplace=False,
-         seqlen_offsets: Union[int, torch.Tensor] = 0,
-         cu_seqlens: Optional[torch.Tensor] = None,
-         max_seqlen: Optional[int] = None,
-     ):
-         out = apply_rotary(
-             x,
-             cos,
-             sin,
-             seqlen_offsets=seqlen_offsets,
-             cu_seqlens=cu_seqlens,
-             max_seqlen=max_seqlen,
-             interleaved=interleaved,
-             inplace=inplace,
-         )
-         if isinstance(seqlen_offsets, int):
-             ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
-             ctx.seqlen_offsets = seqlen_offsets
-         else:
-             ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
-             ctx.seqlen_offsets = None
-         ctx.interleaved = interleaved
-         ctx.inplace = inplace
-         ctx.max_seqlen = max_seqlen
-         return out if not inplace else x
-
-     @staticmethod
-     def backward(ctx, do):
-         seqlen_offsets = ctx.seqlen_offsets
-         if seqlen_offsets is None:
-             cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
-         else:
-             cos, sin, cu_seqlens = ctx.saved_tensors
-         # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
-         # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
-         if not ctx.interleaved and not ctx.inplace:
-             do = do.clone()
-         dx = apply_rotary(
-             do,
-             cos,
-             sin,
-             seqlen_offsets=seqlen_offsets,
-             cu_seqlens=cu_seqlens,
-             max_seqlen=ctx.max_seqlen,
-             interleaved=ctx.interleaved,
-             inplace=ctx.inplace,
-             conjugate=True,
-         )
-         return dx, None, None, None, None, None, None, None
-
-
301
- def apply_rotary_emb(
302
- x,
303
- cos,
304
- sin,
305
- interleaved=False,
306
- inplace=False,
307
- seqlen_offsets: Union[int, torch.Tensor] = 0,
308
- cu_seqlens: Optional[torch.Tensor] = None,
309
- max_seqlen: Optional[int] = None,
310
- ):
311
- """
312
- Arguments:
313
- x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
314
- else (total_seqlen, nheads, headdim)
315
- cos, sin: (seqlen_rotary, rotary_dim / 2)
316
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
317
- of 1st half and 2nd half (GPT-NeoX style).
318
- inplace: if True, apply rotary embedding in-place.
319
- seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
320
- Most commonly used in inference when we have KV cache.
321
- cu_seqlens: (batch + 1,) or None
322
- max_seqlen: int
323
- Return:
324
- out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
325
- else (total_seqlen, nheads, headdim)
326
- rotary_dim must be <= headdim
327
- Apply rotary embedding to the first rotary_dim of x.
328
- """
329
- return ApplyRotaryEmb.apply(
330
- x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
331
- )
332
-
333
-
334
- # For backward compatibility
335
- apply_rotary_emb_func = apply_rotary_emb
336
-
337
-
- class FastRotaryEmbedding(torch.nn.Module):
-     """
-     The rotary position embeddings from RoFormer_ (Su et al.).
-     A crucial insight from the method is that the query and keys are
-     transformed by rotation matrices which depend on the relative positions.
-
-     Other implementations are available in the Rotary Transformer repo_ and in
-     GPT-NeoX_; GPT-NeoX was an inspiration.
-
-     .. _RoFormer: https://arxiv.org/abs/2104.09864
-     .. _repo: https://github.com/ZhuiyiTechnology/roformer
-     .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
-
-     If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
-     A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
-     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
-     """
-
-     def __init__(
-         self,
-         dim: int,
-         base=10000,
-         interleaved=False,
-         scale_base=None,
-         pos_idx_in_fp32=True,
-         device=None,
-     ):
-         """
-         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-             of 1st half and 2nd half (GPT-NeoX style).
-         pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
-             otherwise they might be in lower precision.
-             This option was added because previously (before 2023-07-02), when we construct
-             the position indices, we use the dtype of self.inv_freq. In most cases this would
-             be fp32, but if the model is trained in pure bf16 (not mixed precision), then
-             self.inv_freq would be bf16, and the position indices are also in bf16.
-             Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
-             embeddings for some positions will coincide.
-             To maintain compatibility with models previously trained in pure bf16,
-             we add this option.
-         """
-         super().__init__()
-         self.dim = dim
-         self.base = base
-         self.pos_idx_in_fp32 = pos_idx_in_fp32
-         # Generate and save the inverse frequency buffer (non-trainable)
-         inv_freq = self._compute_inv_freq(device)
-         self.register_buffer("inv_freq", inv_freq)
-         self.interleaved = interleaved
-         self.scale_base = scale_base
-         scale = (
-             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
-             if scale_base is not None
-             else None
-         )
-         self.register_buffer("scale", scale, persistent=False)
-
-         self._seq_len_cached = 0
-         self._cos_cached = None
-         self._sin_cached = None
-         self._cos_k_cached = None
-         self._sin_k_cached = None
-         self.cos = None
-         self.sin = None
-
-     def _compute_inv_freq(self, device=None):
-         return 1.0 / (
-             self.base
-             ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
-             # ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
-         )
-
-     def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
-         if seqlen > self._seq_len_cached:
-             self._seq_len_cached = seqlen
-             # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
-             # And the output of arange can be quite large, so bf16 would lose a lot of precision.
-             # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
-             if self.pos_idx_in_fp32:
-                 t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                 # We want fp32 here as well since inv_freq will be multiplied with t, and the output
-                 # will be large. Having it in bf16 will lose a lot of precision and cause the
-                 # cos & sin output to change significantly.
-                 # We want to recompute self.inv_freq if it was not loaded in fp32
-                 if self.inv_freq.dtype != torch.float32:
-                     inv_freq = self._compute_inv_freq(device=device)
-                 else:
-                     inv_freq = self.inv_freq
-             else:
-                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-                 inv_freq = self.inv_freq
-             freqs = torch.einsum("i,j->ij", t, inv_freq)
-             if self.scale is None:
-                 self._cos_cached = torch.cos(freqs).to(dtype)
-                 self._sin_cached = torch.sin(freqs).to(dtype)
-
-             else:
-                 power = (
-                     torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
-                     - seqlen // 2
-                 ) / self.scale_base
-                 scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-                 # We want the multiplication by scale to happen in fp32
-                 self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-                 self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-                 self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-                 self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
-
-     def forward(
-         self,
-         q: torch.Tensor,
-         k: torch.Tensor,
-         position_ids: torch.Tensor,
-         max_seqlen,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         q: (batch, nheads, seqlen, headdim)
-         k: (batch, nheads, seqlen, headdim)
-         position_ids: (batch, seqlen)
-         max_seqlen: int
-         Apply rotary embedding *in place* to q and k.
-         """
-
-         self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
-         cos, sin = F.embedding(position_ids, self._cos_cached), F.embedding(position_ids, self._sin_cached)
-
-         q = apply_rotary_emb_func(
-             q,
-             cos,
-             sin,
-             interleaved=self.interleaved,
-             inplace=True
-         )
-         k = apply_rotary_emb_func(
-             k,
-             cos,
-             sin,
-             interleaved=self.interleaved,
-             inplace=True
-         )
-         return q, k
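
For reference, below is a minimal sketch of how the interface removed in this commit is typically driven. It assumes a CUDA device with Triton available; the import path, tensor sizes, and dtype are illustrative, not taken from the repository.

    import torch

    # from util import FastRotaryEmbedding  # illustrative import path; util.py is the file deleted above

    batch, nheads, seqlen, headdim = 2, 32, 128, 128   # illustrative sizes
    device, dtype = "cuda", torch.float16              # the Triton kernel expects CUDA tensors

    q = torch.randn(batch, nheads, seqlen, headdim, device=device, dtype=dtype)
    k = torch.randn(batch, nheads, seqlen, headdim, device=device, dtype=dtype)
    position_ids = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch, -1)

    # Rotate the full head dimension (rotary_dim == headdim here).
    rotary = FastRotaryEmbedding(dim=headdim, base=10000, interleaved=False, device=device)

    # forward() caches cos/sin up to max_seqlen, gathers them per position via F.embedding,
    # and applies the Triton kernel in place to q and k.
    q, k = rotary(q, k, position_ids, max_seqlen=seqlen)

The lower-level apply_rotary_emb / apply_rotary_emb_func entry points take explicit cos/sin tables of shape (batch, seqlen, rotary_dim / 2) plus an inplace flag, which is exactly how FastRotaryEmbedding.forward calls them.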
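
The rotation itself is easiest to see in the non-interleaved branch of rotary_kernel: the first rotary_dim features of each head are split into two halves and combined as o0 = x0 * cos - x1 * sin and o1 = x0 * sin + x1 * cos, while the remaining features pass through unchanged. A plain-PyTorch reference of that transform, useful for checking the kernel's output (a sketch, not part of the deleted file):

    import torch

    def rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # x:        (batch, nheads, seqlen, headdim)
        # cos, sin: (batch, seqlen, rotary_dim // 2), the layout apply_rotary expects
        rotary_dim = 2 * cos.shape[-1]
        x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
        x0, x1 = x_rot.chunk(2, dim=-1)        # 1st and 2nd halves of the rotary features
        cos = cos[:, None, :, :].to(x.dtype)   # broadcast over the heads axis
        sin = sin[:, None, :, :].to(x.dtype)
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        return torch.cat([o0, o1, x_pass], dim=-1)  # features beyond rotary_dim pass through

Setting conjugate=True, as ApplyRotaryEmb.backward does, applies the same formula with sin negated, i.e. the inverse rotation, which is why the backward pass can reuse apply_rotary.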