mridulk committed
Commit 6226078
1 Parent(s): 616c8f3

working on arc

app.py CHANGED
@@ -2,17 +2,296 @@ import torch
 import gradio as gr
 
 
-def doo(text_prompt):
-    return
+import argparse, os, sys, glob
+import torch
+import pickle
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from einops import rearrange
+from torchvision.utils import make_grid
+
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+    print(f"Loading model from {ckpt}")
+    # pl_sd = torch.load(ckpt, map_location="cpu")
+    pl_sd = torch.load(ckpt)  # , map_location="cpu")
+    sd = pl_sd["state_dict"]
+    model = instantiate_from_config(config.model)
+    m, u = model.load_state_dict(sd, strict=False)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
+
+    model.cuda()
+    model.eval()
+    return model
+
+
+def masking_embed(embedding, levels=1):
+    """
+    embedding has shape n x 1 x d (n samples, d = 512);
+    replaces the last 128*levels dimensions with random noise
+    """
+    replace_size = 128 * levels
+    random_noise = torch.randn(embedding.shape[0], embedding.shape[1], replace_size)
+    embedding[:, :, -replace_size:] = random_noise
+    return embedding
+
+
+def generate_image(fish_name, masking_level_input,
+                   swap_fish_name, swap_level_input):
+
+    fish_name = fish_name.lower()
+
+    ckpt_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/checkpoints/epoch=000119.ckpt'
+    config_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/configs/2024-03-01T23-15-36-project.yaml'
+
+    label_to_class_mapping = {
+        0: 'Alosa-chrysochloris', 1: 'Carassius-auratus', 2: 'Cyprinus-carpio', 3: 'Esox-americanus',
+        4: 'Gambusia-affinis', 5: 'Lepisosteus-osseus', 6: 'Lepisosteus-platostomus', 7: 'Lepomis-auritus',
+        8: 'Lepomis-cyanellus', 9: 'Lepomis-gibbosus', 10: 'Lepomis-gulosus', 11: 'Lepomis-humilis',
+        12: 'Lepomis-macrochirus', 13: 'Lepomis-megalotis', 14: 'Lepomis-microlophus', 15: 'Morone-chrysops',
+        16: 'Morone-mississippiensis', 17: 'Notropis-atherinoides', 18: 'Notropis-blennius', 19: 'Notropis-boops',
+        20: 'Notropis-buccatus', 21: 'Notropis-buchanani', 22: 'Notropis-dorsalis', 23: 'Notropis-hudsonius',
+        24: 'Notropis-leuciodus', 25: 'Notropis-nubilus', 26: 'Notropis-percobromus', 27: 'Notropis-stramineus',
+        28: 'Notropis-telescopus', 29: 'Notropis-texanus', 30: 'Notropis-volucellus', 31: 'Notropis-wickliffi',
+        32: 'Noturus-exilis', 33: 'Noturus-flavus', 34: 'Noturus-gyrinus', 35: 'Noturus-miurus',
+        36: 'Noturus-nocturnus', 37: 'Phenacobius-mirabilis'}
+
+    def get_label_from_class(class_name):
+        for key, value in label_to_class_mapping.items():
+            if value == class_name:
+                return key
+
+    config = OmegaConf.load(config_path)  # TODO: optionally download from the same location as ckpt and change this logic
+    model = load_model_from_config(config, ckpt_path)  # TODO: check path
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model = model.to(device)
+
+    if opt.plms:
+        sampler = PLMSSampler(model)
+    else:
+        sampler = DDIMSampler(model)
+
+    os.makedirs(opt.outdir, exist_ok=True)
+    outpath = opt.outdir
+
+    prompt = opt.prompt
+    all_images = []
+    labels = []
+
+    class_to_node = '/fastscratch/mridul/fishes/class_to_ancestral_label.pkl'
+    with open(class_to_node, 'rb') as pickle_file:
+        class_to_node_dict = pickle.load(pickle_file)
+
+    # normalize keys so the species textbox is case-insensitive
+    class_to_node_dict = {key.lower(): value for key, value in class_to_node_dict.items()}
+
+    sample_path = os.path.join(outpath, opt.output_dir_name)
+    os.makedirs(sample_path, exist_ok=True)
+    base_count = len(os.listdir(sample_path))
+
+    prompt = class_to_node_dict[fish_name]
+    if swap_fish_name:
+        # splice one ancestral level of the swap species into the prompt
+        swap_level = int(swap_level_input.split(" ")[-1]) - 1
+        swap_fish = class_to_node_dict[swap_fish_name]
+
+        swap_fish_split = swap_fish[0].split(',')
+        fish_name_split = prompt[0].split(',')
+        fish_name_split[swap_level] = swap_fish_split[swap_level]
+
+        prompt = [','.join(fish_name_split)]
+
+    all_samples = list()
+    with torch.no_grad():
+        with model.ema_scope():
+            uc = None
+            for n in trange(opt.n_iter, desc="Sampling"):
 
-iface = gr.Interface(
-    fn=doo,
-    inputs=gr.Textbox(label="Prompt"),
-    outputs=[
-        gr.Image(label="Generated Image"),
-    ]
-)
+                all_prompts = opt.n_samples * (prompt)
+                all_prompts = [tuple(all_prompts)]
+                c = model.get_learned_conditioning({'class_to_node': all_prompts})
+                if masking_level_input != "None":
+                    masked_level = int(masking_level_input.split(" ")[-1])
+                    masked_level = 4 - masked_level
+                    c = masking_embed(c, levels=masked_level)
+                shape = [3, 64, 64]
+                samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                 conditioning=c,
+                                                 batch_size=opt.n_samples,
+                                                 shape=shape,
+                                                 verbose=False,
+                                                 unconditional_guidance_scale=opt.scale,
+                                                 unconditional_conditioning=uc,
+                                                 eta=opt.ddim_eta)
+
+                x_samples_ddim = model.decode_first_stage(samples_ddim)
+                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+                all_samples.append(x_samples_ddim)
+
+    # additionally, arrange the samples in a grid
+    grid = torch.stack(all_samples, 0)
+    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+    grid = make_grid(grid, nrow=opt.n_samples)
+
+    # to image
+    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+    final_image = Image.fromarray(grid.astype(np.uint8))
+    # final_image.save(os.path.join(sample_path, f'{class_name.replace(" ", "-")}.png'))
+
+    return final_image
 
 
 if __name__ == "__main__":
-    iface.launch()
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        nargs="?",
+        default="a painting of a virus monster playing guitar",
+        help="the prompt to render",
+    )
+    parser.add_argument(
+        "--outdir",
+        type=str,
+        nargs="?",
+        help="dir to write results to",
+        default="outputs/txt2img-samples",
+    )
+    parser.add_argument(
+        "--ddim_steps",
+        type=int,
+        default=200,
+        help="number of ddim sampling steps",
+    )
+    parser.add_argument(
+        "--plms",
+        action='store_true',
+        help="use plms sampling",
+    )
+    parser.add_argument(
+        "--ddim_eta",
+        type=float,
+        default=1.0,
+        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
+    )
+    parser.add_argument(
+        "--n_iter",
+        type=int,
+        default=1,
+        help="sample this often",
+    )
+    parser.add_argument(
+        "--H",
+        type=int,
+        default=256,
+        help="image height, in pixel space",
+    )
+    parser.add_argument(
+        "--W",
+        type=int,
+        default=256,
+        help="image width, in pixel space",
+    )
+    parser.add_argument(
+        "--n_samples",
+        type=int,
+        default=1,
+        help="how many samples to produce for the given prompt",
+    )
+    parser.add_argument(
+        "--output_dir_name",
+        type=str,
+        default='default_file',
+        help="name of the output folder",
+    )
+    parser.add_argument(
+        "--postfix",
+        type=str,
+        default='',
+        help="postfix for the output folder name",
+    )
+    parser.add_argument(
+        "--scale",
+        type=float,
+        # default=5.0,
+        default=1.0,
+        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+    )
+    opt = parser.parse_args()
+
+    def setup_interface():
+        with gr.Blocks() as demo:
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("### Generate Images Based on Prompts")
+                    gr.Markdown("Enter a prompt to generate an image:")
+                    prompt_input = gr.Textbox(label="Species Name")
+                    gr.Markdown("Trait Masking")
+                    with gr.Row():
+                        masking_level_input = gr.Dropdown(label="Select Ancestral Level", choices=["None", "Level 3", "Level 2"], value="None")
+                        # masking_node_input = gr.Dropdown(label="Select Internal", choices=["0", "1", "2", "3", "4", "5", "6", "7", "8"], value="0")
+
+                    gr.Markdown("Trait Swapping")
+                    with gr.Row():
+                        swap_fish_name = gr.Textbox(label="Species Name to swap trait with:")
+                        swap_level_input = gr.Dropdown(label="Level of swapping", choices=["Level 3", "Level 2"], value="Level 3")
+                    submit_button = gr.Button("Generate")
+                    gr.Markdown("### Phylogeny Tree")
+                    architecture_image = "phylogeny_tree.jpg"  # update this with the actual path
+                    gr.Image(value=architecture_image, label="Phylogeny Tree")
+
+                with gr.Column():
+                    gr.Markdown("### Generated Image")
+                    output_image = gr.Image(label="Generated Image")
+
+            submit_button.click(
+                fn=generate_image,
+                inputs=[prompt_input, masking_level_input,
+                        swap_fish_name, swap_level_input],
+                outputs=output_image,
+            )
+
+        return demo
+
+    # previous plain gr.Interface wiring, kept for reference:
+    # iface = gr.Interface(
+    #     fn=generate_image,
+    #     inputs=gr.Textbox(label="Prompt"),
+    #     outputs=[
+    #         gr.Image(label="Generated Image"),
+    #     ]
+    # )
+
+    iface = setup_interface()
+
+    iface.launch(share=True)
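Note on the two new behaviors above: trait masking overwrites the trailing slice of the 512-d conditioning embedding with Gaussian noise (a UI choice of "Level k" maps to levels = 4 - k, i.e. 128*(4 - k) dimensions), and trait swapping replaces one comma-separated ancestral level of the prompt string with the same level from another species. A minimal self-contained sketch of both, using a dummy tensor and made-up prompt strings (the real prompts come from the pickled class_to_node_dict, whose exact format is not shown in this diff):

import torch

def masking_embed(embedding, levels=1):
    # same logic as in app.py: randomize the last 128*levels dims
    replace_size = 128 * levels
    embedding[:, :, -replace_size:] = torch.randn(
        embedding.shape[0], embedding.shape[1], replace_size)
    return embedding

c = torch.zeros(1, 1, 512)                       # stand-in for model.get_learned_conditioning(...)
masked = masking_embed(c.clone(), levels=4 - 2)  # UI "Level 2" -> levels = 2
assert masked[:, :, :256].abs().sum() == 0       # first 256 dims untouched
assert masked[:, :, 256:].abs().sum() > 0        # last 256 dims randomized

# trait swapping on hypothetical ancestral-label prompts
prompt = ['node_a,node_b,node_c,Lepomis-auritus']
swap_fish = ['node_x,node_y,node_z,Notropis-boops']
swap_level = int("Level 3".split(" ")[-1]) - 1   # -> index 2
parts = prompt[0].split(',')
parts[swap_level] = swap_fish[0].split(',')[swap_level]
print([','.join(parts)])                         # ['node_a,node_b,node_z,Lepomis-auritus']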
ldm/models/__pycache__/autoencoder.cpython-38.pyc CHANGED
Binary files a/ldm/models/__pycache__/autoencoder.cpython-38.pyc and b/ldm/models/__pycache__/autoencoder.cpython-38.pyc differ
 
ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc and b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc differ
 
ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc CHANGED
Binary files a/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc and b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc differ
 
ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc CHANGED
Binary files a/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc and b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc differ
 
ldm/models/diffusion/__pycache__/plms.cpython-38.pyc CHANGED
Binary files a/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc and b/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc differ
 
ldm/models/disentanglement/__pycache__/iterative_normalization.cpython-38.pyc CHANGED
Binary files a/ldm/models/disentanglement/__pycache__/iterative_normalization.cpython-38.pyc and b/ldm/models/disentanglement/__pycache__/iterative_normalization.cpython-38.pyc differ
 
ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc and b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc differ
 
ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc CHANGED
Binary files a/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc and b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc differ
 
ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc CHANGED
Binary files a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc and b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc differ
 
ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc CHANGED
Binary files a/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc and b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc differ
 
ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc and b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ
 
ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc CHANGED
Binary files a/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc and b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ
 
ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc and b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ
 
ldm/modules/encoders/__pycache__/modules.cpython-38.pyc CHANGED
Binary files a/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc and b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ
 
phylogeny_tree.jpg ADDED
phylogeny_tree.pdf ADDED
Binary file (236 kB)
 
phylogeny_tree.png ADDED
sample_level_encoding.py CHANGED
@@ -32,6 +32,16 @@ def load_model_from_config(config, ckpt, verbose=False):
     model.eval()
     return model
 
+def masking_embed(embedding, levels=1):
+    """
+    embedding has shape n x 1 x d (n samples, d = 512);
+    replaces the last 128*levels dimensions with random noise
+    """
+    replace_size = 128 * levels
+    random_noise = torch.randn(embedding.shape[0], embedding.shape[1], replace_size)
+    embedding[:, :, -replace_size:] = random_noise
+    return embedding
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -206,6 +216,9 @@ if __name__ == "__main__":
     with open(class_to_node, 'rb') as pickle_file:
         class_to_node_dict = pickle.load(pickle_file)
 
+    # normalize keys for case-insensitive species lookup
+    class_to_node_dict = {key.lower(): value for key, value in class_to_node_dict.items()}
+
     sample_path = os.path.join(outpath, opt.output_dir_name)
     os.makedirs(sample_path, exist_ok=True)
     base_count = len(os.listdir(sample_path))
@@ -223,6 +236,7 @@ if __name__ == "__main__":
             all_prompts = opt.n_samples * (prompt)
             all_prompts = [tuple(all_prompts)]
             print(class_name, prompt)
+            breakpoint()  # debugging breakpoint
             c = model.get_learned_conditioning({'class_to_node': all_prompts})
             shape = [3, 64, 64]
             samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
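Both files now lower-case the keys of class_to_node_dict before lookup. The pickle's actual contents are not part of this diff; with a hypothetical entry, the normalization works like this:

# hypothetical entry; the real mapping is loaded from the pickled class_to_node file
class_to_node_dict = {'Lepomis-Auritus': ['node_a,node_b,node_c,Lepomis-Auritus']}

class_to_node_dict = {key.lower(): value for key, value in class_to_node_dict.items()}

# user input from the Gradio textbox is lower-cased the same way,
# so the lookup no longer depends on capitalization
fish_name = 'Lepomis-auritus'.lower()
print(class_to_node_dict[fish_name])  # ['node_a,node_b,node_c,Lepomis-Auritus']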