Committed by nyanko7
Commit 43f0396
Parent: 8d7886f

Create model.py

Files changed (1): modules/model.py (+963, -0)
modules/model.py ADDED
@@ -0,0 +1,963 @@
+import importlib
+import inspect
+import math
+from pathlib import Path
+import re
+from collections import defaultdict
+from typing import List, Optional, Union
+
+import k_diffusion
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+from modules.prompt_parser import FrozenCLIPEmbedderWithCustomWords
+from torch import einsum
+from torch.autograd.function import Function
+
+from diffusers import DiffusionPipeline
+from diffusers.utils import PIL_INTERPOLATION, is_accelerate_available
+from diffusers.utils import logging, randn_tensor
+
+import modules.safe as _
+from safetensors.torch import load_file
+
+xformers_available = False
+try:
+    import xformers
+    xformers_available = True
+except ImportError:
+    pass
+
+EPSILON = 1e-6
+exists = lambda val: val is not None
+default = lambda val, d: val if exists(val) else d
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+def get_attention_scores(attn, query, key, attention_mask=None):
+
+    if attn.upcast_attention:
+        query = query.float()
+        key = key.float()
+
+    attention_scores = torch.baddbmm(
+        torch.empty(
+            query.shape[0],
+            query.shape[1],
+            key.shape[1],
+            dtype=query.dtype,
+            device=query.device,
+        ),
+        query,
+        key.transpose(-1, -2),
+        beta=0,
+        alpha=attn.scale,
+    )
+
+    if attention_mask is not None:
+        attention_scores = attention_scores + attention_mask
+
+    if attn.upcast_softmax:
+        attention_scores = attention_scores.float()
+
+    return attention_scores
+
+
+def load_lora_attn_procs(model_file, unet, scale=1.0):
+
+    if Path(model_file).suffix == ".pt":
+        state_dict = torch.load(model_file, map_location="cpu")
+    else:
+        state_dict = load_file(model_file, device="cpu")
+
+    # 'lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight'
+    # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight'
+    if any("lora_unet_down_blocks" in k for k in state_dict.keys()):
+        # extract ldm format lora
+        df_lora = {}
+        attn_numlayer = re.compile(r'_attn(\d)_to_([qkv]|out).lora_')
+        alpha_numlayer = re.compile(r'_attn(\d)_to_([qkv]|out).alpha')
+        for k, v in state_dict.items():
+            if "attn" not in k or "lora_te" in k:
+                # currently not supported: ff, clip-attn
+                continue
+            k = k.replace("lora_unet_down_blocks_", "down_blocks.")
+            k = k.replace("lora_unet_up_blocks_", "up_blocks.")
+            k = k.replace("lora_unet_mid_block_", "mid_block_")
+            k = k.replace("_attentions_", ".attentions.")
+            k = k.replace("_transformer_blocks_", ".transformer_blocks.")
+            k = k.replace("to_out_0", "to_out")
+            k = attn_numlayer.sub(r'.attn\1.processor.to_\2_lora.', k)
+            k = alpha_numlayer.sub(r'.attn\1.processor.to_\2_lora.alpha', k)
+            df_lora[k] = v
+        state_dict = df_lora
+
+    # fill attn processors
+    attn_processors = {}
+
+    is_lora = all("lora" in k for k in state_dict.keys())
+
+    if is_lora:
+        lora_grouped_dict = defaultdict(dict)
+        for key, value in state_dict.items():
+            if "alpha" in key:
+                attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+            else:
+                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+            lora_grouped_dict[attn_processor_key][sub_key] = value
+
+        for key, value_dict in lora_grouped_dict.items():
+            rank = value_dict["to_k_lora.down.weight"].shape[0]
+            cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
+            hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
+            attn_processors[key] = LoRACrossAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank, scale=scale
+            )
+            attn_processors[key].load_state_dict(value_dict, strict=False)
+
+    else:
+        raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+
+    # set correct dtype & device
+    attn_processors = {k: v.to(device=unet.device, dtype=unet.dtype) for k, v in attn_processors.items()}
+
+    # set layers
+    unet.set_attn_processor(attn_processors)
+
+
+class CrossAttnProcessor(nn.Module):
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, qkvo_bias=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        encoder_states = hidden_states
+        is_xattn = False
+        if encoder_hidden_states is not None:
+            is_xattn = True
+            img_state = encoder_hidden_states["img_state"]
+            encoder_states = encoder_hidden_states["states"]
+            weight_func = encoder_hidden_states["weight_func"]
+            sigma = encoder_hidden_states["sigma"]
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_states)
+        value = attn.to_v(encoder_states)
+
+        if qkvo_bias is not None:
+            query += qkvo_bias["q"](hidden_states)
+            key += qkvo_bias["k"](encoder_states)
+            value += qkvo_bias["v"](encoder_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        if is_xattn and isinstance(img_state, dict):
+            # use torch.baddbmm method (slow)
+            attention_scores = get_attention_scores(attn, query, key, attention_mask)
+            w = img_state[sequence_length].to(query.device)
+            cross_attention_weight = weight_func(w, sigma, attention_scores)
+            attention_scores += torch.repeat_interleave(cross_attention_weight, repeats=attn.heads, dim=0)
+
+            # calc probs
+            attention_probs = attention_scores.softmax(dim=-1)
+            attention_probs = attention_probs.to(query.dtype)
+            hidden_states = torch.bmm(attention_probs, value)
+
+        elif xformers_available:
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query.contiguous(), key.contiguous(), value.contiguous(), attn_bias=attention_mask
+            )
+            hidden_states = hidden_states.to(query.dtype)
+
+        else:
+            q_bucket_size = 512
+            k_bucket_size = 1024
+
+            # use flash-attention
+            hidden_states = FlashAttentionFunction.apply(
+                query.contiguous(), key.contiguous(), value.contiguous(),
+                attention_mask, causal=False, q_bucket_size=q_bucket_size, k_bucket_size=k_bucket_size
+            )
+            hidden_states = hidden_states.to(query.dtype)
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+
+        if qkvo_bias is not None:
+            hidden_states += qkvo_bias["o"](hidden_states)
+
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class LoRACrossAttnProcessor(CrossAttnProcessor):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, scale=1.0):
+        super().__init__()
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.scale = scale
+
+    def __call__(
+        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
+    ):
+        scale = self.scale
+        qkvo_bias = {
+            "q": lambda inputs: scale * self.to_q_lora(inputs),
+            "k": lambda inputs: scale * self.to_k_lora(inputs),
+            "v": lambda inputs: scale * self.to_v_lora(inputs),
+            "o": lambda inputs: scale * self.to_out_lora(inputs),
+        }
+        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, qkvo_bias)
+
+
+class LoRALinearLayer(nn.Module):
+    def __init__(self, in_features, out_features, rank=4):
+        super().__init__()
+
+        if rank > min(in_features, out_features):
+            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")
+
+        self.down = nn.Linear(in_features, rank, bias=False)
+        self.up = nn.Linear(rank, out_features, bias=False)
+        self.scale = 1.0
+        self.alpha = rank
+
+        nn.init.normal_(self.down.weight, std=1 / rank)
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.down.weight.dtype
+        rank = self.down.out_features
+
+        down_hidden_states = self.down(hidden_states.to(dtype))
+        up_hidden_states = self.up(down_hidden_states) * (self.alpha / rank)
+
+        return up_hidden_states.to(orig_dtype)
+
+
+class ModelWrapper:
+    def __init__(self, model, alphas_cumprod):
+        self.model = model
+        self.alphas_cumprod = alphas_cumprod
+
+    def apply_model(self, *args, **kwargs):
+        if len(args) == 3:
+            encoder_hidden_states = args[-1]
+            args = args[:2]
+        if kwargs.get("cond", None) is not None:
+            encoder_hidden_states = kwargs.pop("cond")
+        return self.model(
+            *args, encoder_hidden_states=encoder_hidden_states, **kwargs
+        ).sample
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+
+    _optional_components = ["safety_checker", "feature_extractor"]
+
+    def __init__(
+        self,
+        vae,
+        text_encoder,
+        tokenizer,
+        unet,
+        scheduler,
+    ):
+        super().__init__()
+
+        # get correct sigmas from LMS
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+        )
+        self.setup_unet(self.unet)
+        self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder)
+
+    def setup_unet(self, unet):
+        unet = unet.to(self.device)
+        model = ModelWrapper(unet, self.scheduler.alphas_cumprod)
+        if self.scheduler.prediction_type == "v_prediction":
+            self.k_diffusion_model = CompVisVDenoiser(model)
+        else:
+            self.k_diffusion_model = CompVisDenoiser(model)
+
+    def get_scheduler(self, scheduler_type: str):
+        library = importlib.import_module("k_diffusion")
+        sampling = getattr(library, "sampling")
+        return getattr(sampling, scheduler_type)
+
+    def encode_sketchs(self, state, scale_ratio=8, g_strength=1.0, text_ids=None):
+        uncond, cond = text_ids[0], text_ids[1]
+
+        img_state = []
+        if state is None:
+            return torch.FloatTensor(0)
+
+        for k, v in state.items():
+            if v["map"] is None:
+                continue
+
+            v_input = self.tokenizer(
+                k,
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                add_special_tokens=False,
+            ).input_ids
+
+            dotmap = v["map"] < 255
+            arr = torch.from_numpy(dotmap.astype(float) * float(v["weight"]) * g_strength)
+            img_state.append((v_input, arr))
+
+        if len(img_state) == 0:
+            return torch.FloatTensor(0)
+
+        w_tensors = dict()
+        cond = cond.tolist()
+        uncond = uncond.tolist()
+        for layer in self.unet.down_blocks:
+            c = int(len(cond))
+            w, h = img_state[0][1].shape
+            w_r, h_r = w // scale_ratio, h // scale_ratio
+
+            ret_cond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
+            ret_uncond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
+
+            for v_as_tokens, img_where_color in img_state:
+                is_in = 0
+
+                ret = F.interpolate(
+                    img_where_color.unsqueeze(0).unsqueeze(1),
+                    scale_factor=1 / scale_ratio,
+                    mode="bilinear",
+                    align_corners=True,
+                ).squeeze().reshape(-1, 1).repeat(1, len(v_as_tokens))
+
+                for idx, tok in enumerate(cond):
+                    if cond[idx : idx + len(v_as_tokens)] == v_as_tokens:
+                        is_in = 1
+                        ret_cond_tensor[0, :, idx : idx + len(v_as_tokens)] += (ret)
+
+                for idx, tok in enumerate(uncond):
+                    if uncond[idx : idx + len(v_as_tokens)] == v_as_tokens:
+                        is_in = 1
+                        ret_uncond_tensor[0, :, idx : idx + len(v_as_tokens)] += (ret)
+
+                if not is_in == 1:
+                    print(f"tokens {v_as_tokens} not found in text")
+
+            w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor])
+            scale_ratio *= 2
+
+        return w_tensors
+
+    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+                `attention_head_dim` must be a multiple of `slice_size`.
+        """
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = self.unet.config.attention_head_dim // 2
+        self.unet.set_attention_slice(slice_size)
+
+    def disable_attention_slicing(self):
+        r"""
+        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+        back to computing attention in one step.
+        """
+        # set slice_size = `None` to disable `attention slicing`
+        self.enable_attention_slicing(None)
+
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        for cpu_offloaded_model in [
+            self.unet,
+            self.text_encoder,
+            self.vae,
+            self.safety_checker,
+        ]:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    def decode_latents(self, latents):
+        latents = latents.to(self.device, dtype=self.vae.dtype)
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+    def check_inputs(self, prompt, height, width, callback_steps):
+        if not isinstance(prompt, str) and not isinstance(prompt, list):
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+            )
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+            )
+
+        if (callback_steps is None) or (
+            callback_steps is not None
+            and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        shape = (batch_size, num_channels_latents, height // 8, width // 8)
+        if latents is None:
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                latents = torch.randn(
+                    shape, generator=generator, device="cpu", dtype=dtype
+                ).to(device)
+            else:
+                latents = torch.randn(
+                    shape, generator=generator, device=device, dtype=dtype
+                )
+        else:
+            # if latents.shape != shape:
+            #     raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        return latents
+
+    def preprocess(self, image):
+        if isinstance(image, torch.Tensor):
+            return image
+        elif isinstance(image, PIL.Image.Image):
+            image = [image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            w, h = image[0].size
+            w, h = map(lambda x: x - x % 8, (w, h))  # resize to integer multiple of 8
+
+            image = [
+                np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
+                    None, :
+                ]
+                for i in image
+            ]
+            image = np.concatenate(image, axis=0)
+            image = np.array(image).astype(np.float32) / 255.0
+            image = image.transpose(0, 3, 1, 2)
+            image = 2.0 * image - 1.0
+            image = torch.from_numpy(image)
+        elif isinstance(image[0], torch.Tensor):
+            image = torch.cat(image, dim=0)
+        return image
+
+    @torch.no_grad()
+    def img2img(
+        self,
+        prompt: Union[str, List[str]],
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        generator: Optional[torch.Generator] = None,
+        image: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        latents=None,
+        strength=1.0,
+        pww_state=None,
+        pww_attn_weight=1.0,
+        sampler_name="",
+        sampler_opt={},
+        scale_ratio=8.0
+    ):
+        sampler = self.get_scheduler(sampler_name)
+        if image is not None:
+            image = self.preprocess(image)
+            image = image.to(self.vae.device, dtype=self.vae.dtype)
+
+            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            latents = 0.18215 * init_latents
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = True
+        if guidance_scale <= 1.0:
+            raise ValueError("has to use guidance_scale")
+
+        # 3. Encode input prompt
+        text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
+        text_embeddings = text_embeddings.to(self.unet.dtype)
+
+        init_timestep = int(num_inference_steps / min(strength, 0.999)) if strength > 0 else 0
+        sigmas = self.get_sigmas(init_timestep, sampler_opt).to(
+            text_embeddings.device, dtype=text_embeddings.dtype
+        )
+
+        t_start = max(init_timestep - num_inference_steps, 0)
+        sigma_sched = sigmas[t_start:]
+
+        noise = randn_tensor(
+            latents.shape,
+            generator=generator,
+            device=device,
+            dtype=text_embeddings.dtype,
+        )
+        latents = latents.to(device)
+        latents = latents + noise * sigma_sched[0]
+
+        # 5. Prepare latent variables
+        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
+        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
+            latents.device
+        )
+
+        img_state = self.encode_sketchs(
+            pww_state,
+            g_strength=pww_attn_weight,
+            text_ids=text_ids,
+        )
+
+        def model_fn(x, sigma):
+
+            latent_model_input = torch.cat([x] * 2)
+            weight_func = (
+                lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
+            )
+            encoder_state = {
+                "img_state": img_state,
+                "states": text_embeddings,
+                "sigma": sigma[0],
+                "weight_func": weight_func,
+            }
+
+            noise_pred = self.k_diffusion_model(
+                latent_model_input, sigma, cond=encoder_state
+            )
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (
+                noise_pred_text - noise_pred_uncond
+            )
+            return noise_pred
+
+        sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
+        latents = sampler(model_fn, latents, **sampler_args)
+
+        # 8. Post-processing
+        image = self.decode_latents(latents)
+
+        # 10. Convert to PIL
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return (image,)
+
+    def get_sigmas(self, steps, params):
+        discard_next_to_last_sigma = params.get("discard_next_to_last_sigma", False)
+        steps += 1 if discard_next_to_last_sigma else 0
+
+        if params.get("scheduler", None) == "karras":
+            sigma_min, sigma_max = (
+                self.k_diffusion_model.sigmas[0].item(),
+                self.k_diffusion_model.sigmas[-1].item(),
+            )
+            sigmas = k_diffusion.sampling.get_sigmas_karras(
+                n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device
+            )
+        else:
+            sigmas = self.k_diffusion_model.get_sigmas(steps)
+
+        if discard_next_to_last_sigma:
+            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+        return sigmas
+
+    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
+    def get_sampler_extra_args_t2i(self, sigmas, eta, steps, func):
+        extra_params_kwargs = {}
+
+        if "eta" in inspect.signature(func).parameters:
+            extra_params_kwargs["eta"] = eta
+
+        if "sigma_min" in inspect.signature(func).parameters:
+            extra_params_kwargs["sigma_min"] = sigmas[0].item()
+            extra_params_kwargs["sigma_max"] = sigmas[-1].item()
+
+        if "n" in inspect.signature(func).parameters:
+            extra_params_kwargs["n"] = steps
+        else:
+            extra_params_kwargs["sigmas"] = sigmas
+
+        return extra_params_kwargs
+
+    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
+    def get_sampler_extra_args_i2i(self, sigmas, func):
+        extra_params_kwargs = {}
+
+        if "sigma_min" in inspect.signature(func).parameters:
+            ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
+            extra_params_kwargs["sigma_min"] = sigmas[-2]
+
+        if "sigma_max" in inspect.signature(func).parameters:
+            extra_params_kwargs["sigma_max"] = sigmas[0]
+
+        if "n" in inspect.signature(func).parameters:
+            extra_params_kwargs["n"] = len(sigmas) - 1
+
+        if "sigma_sched" in inspect.signature(func).parameters:
+            extra_params_kwargs["sigma_sched"] = sigmas
+
+        if "sigmas" in inspect.signature(func).parameters:
+            extra_params_kwargs["sigmas"] = sigmas
+
+        return extra_params_kwargs
+
+    @torch.no_grad()
+    def txt2img(
+        self,
+        prompt: Union[str, List[str]],
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        callback_steps: Optional[int] = 1,
+        upscale=False,
+        upscale_x: float = 2.0,
+        upscale_method: str = "bicubic",
+        upscale_antialias: bool = False,
+        upscale_denoising_strength: float = 0.7,
+        pww_state=None,
+        pww_attn_weight=1.0,
+        sampler_name="",
+        sampler_opt={},
+    ):
+        sampler = self.get_scheduler(sampler_name)
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = True
+        if guidance_scale <= 1.0:
+            raise ValueError("has to use guidance_scale")
+
+        # 3. Encode input prompt
+        text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
+        text_embeddings = text_embeddings.to(self.unet.dtype)
+
+        # 4. Prepare timesteps
+        sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to(
+            text_embeddings.device, dtype=text_embeddings.dtype
+        )
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.in_channels
+        latents = self.prepare_latents(
+            batch_size,
+            num_channels_latents,
+            height,
+            width,
+            text_embeddings.dtype,
+            device,
+            generator,
+            latents,
+        )
+        latents = latents * sigmas[0]
+        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
+        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
+            latents.device
+        )
+
+        img_state = self.encode_sketchs(
+            pww_state,
+            g_strength=pww_attn_weight,
+            text_ids=text_ids,
+        )
+
+        def model_fn(x, sigma):
+
+            latent_model_input = torch.cat([x] * 2)
+            weight_func = (
+                lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
+            )
+            encoder_state = {
+                "img_state": img_state,
+                "states": text_embeddings,
+                "sigma": sigma[0],
+                "weight_func": weight_func,
+            }
+
+            noise_pred = self.k_diffusion_model(
+                latent_model_input, sigma, cond=encoder_state
+            )
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (
+                noise_pred_text - noise_pred_uncond
+            )
+            return noise_pred
+
+        extra_args = self.get_sampler_extra_args_t2i(
+            sigmas, eta, num_inference_steps, sampler
+        )
+        latents = sampler(model_fn, latents, **extra_args)
+
+        if upscale:
+            target_height = height * upscale_x
+            target_width = width * upscale_x
+            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+            latents = torch.nn.functional.interpolate(
+                latents,
+                size=(
+                    int(target_height // vae_scale_factor),
+                    int(target_width // vae_scale_factor),
+                ),
+                mode=upscale_method,
+                antialias=upscale_antialias,
+            )
+            return self.img2img(
+                prompt=prompt,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                negative_prompt=negative_prompt,
+                generator=generator,
+                latents=latents,
+                strength=upscale_denoising_strength,
+                sampler_name=sampler_name,
+                sampler_opt=sampler_opt,
+                pww_state=None,
+                pww_attn_weight=pww_attn_weight/2,
+            )
+
+        # 8. Post-processing
+        image = self.decode_latents(latents)
+
+        # 10. Convert to PIL
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return (image,)
+
+
+class FlashAttentionFunction(Function):
+
+
+    @staticmethod
+    @torch.no_grad()
+    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+        """ Algorithm 2 in the paper """
+
+        device = q.device
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        o = torch.zeros_like(q)
+        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
+        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)
+
+        scale = (q.shape[-1] ** -0.5)
+
+        if not exists(mask):
+            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+        else:
+            mask = rearrange(mask, 'b n -> b 1 1 n')
+            mask = mask.split(q_bucket_size, dim = -1)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim = -2),
+            o.split(q_bucket_size, dim = -2),
+            mask,
+            all_row_sums.split(q_bucket_size, dim = -2),
+            all_row_maxes.split(q_bucket_size, dim = -2),
+        )
+
+        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim = -2),
+                v.split(k_bucket_size, dim = -2),
+            )
+
+            for k_ind, (kc, vc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
+
+                if exists(row_mask):
+                    attn_weights.masked_fill_(~row_mask, max_neg_value)
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
+                attn_weights -= block_row_maxes
+                exp_weights = torch.exp(attn_weights)
+
+                if exists(row_mask):
+                    exp_weights.masked_fill_(~row_mask, 0.)
+
+                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)
+
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
+
+                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
+                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
+
+                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
+
+                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+
+                row_maxes.copy_(new_row_maxes)
+                row_sums.copy_(new_row_sums)
+
+        lse = all_row_sums.log() + all_row_maxes
+
+        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
+        ctx.save_for_backward(q, k, v, o, lse)
+
+        return o
+
+    @staticmethod
+    @torch.no_grad()
+    def backward(ctx, do):
+        """ Algorithm 4 in the paper """
+
+        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
+        q, k, v, o, lse = ctx.saved_tensors
+
+        device = q.device
+
+        max_neg_value = -torch.finfo(q.dtype).max
+        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+        dq = torch.zeros_like(q)
+        dk = torch.zeros_like(k)
+        dv = torch.zeros_like(v)
+
+        row_splits = zip(
+            q.split(q_bucket_size, dim = -2),
+            o.split(q_bucket_size, dim = -2),
+            do.split(q_bucket_size, dim = -2),
+            mask,
+            lse.split(q_bucket_size, dim = -2),
+            dq.split(q_bucket_size, dim = -2)
+        )
+
+        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
+            q_start_index = ind * q_bucket_size - qk_len_diff
+
+            col_splits = zip(
+                k.split(k_bucket_size, dim = -2),
+                v.split(k_bucket_size, dim = -2),
+                dk.split(k_bucket_size, dim = -2),
+                dv.split(k_bucket_size, dim = -2),
+            )
+
+            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+                k_start_index = k_ind * k_bucket_size
+
+                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
+
+                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
+                    attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                p = torch.exp(attn_weights - lsec)
+
+                if exists(row_mask):
+                    p.masked_fill_(~row_mask, 0.)
+
+                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
+                dp = einsum('... i d, ... j d -> ... i j', doc, vc)
+
+                D = (doc * oc).sum(dim = -1, keepdims = True)
+                ds = p * scale * (dp - D)
+
+                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
+                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
+
+                dqc.add_(dq_chunk)
+                dkc.add_(dk_chunk)
+                dvc.add_(dv_chunk)
+
+        return dq, dk, dv, None, None, None, None