menghanxia committed
Commit b3640b9
1 Parent(s): 4636eb6

upload whole project

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Menghan Xia
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py CHANGED
@@ -1,7 +1,92 @@
  import gradio as gr
- 
- def greet(name):
-     return "Hello " + name + "!!"
- 
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import os, requests
+ from inference import setup_model, colorize_grayscale, predict_anchors
+ 
+ ## download checkpoint
+ def download_file_from_google_drive(id, destination):
+     def get_confirm_token(response):
+         for key, value in response.cookies.items():
+             if key.startswith('download_warning'):
+                 return value
+         return None
+ 
+     def save_response_content(response, destination):
+         CHUNK_SIZE = 32768
+         with open(destination, "wb") as f:
+             for chunk in response.iter_content(CHUNK_SIZE):
+                 if chunk:  # filter out keep-alive new chunks
+                     f.write(chunk)
+ 
+     URL = "https://docs.google.com/uc?export=download"
+     session = requests.Session()
+     response = session.get(URL, params={'id': id}, stream=True)
+     token = get_confirm_token(response)
+ 
+     if token:
+         params = {'id': id, 'confirm': token}
+         response = session.get(URL, params=params, stream=True)
+     save_response_content(response, destination)
+ 
+ id = "1J4vB6kG4xBLUUKpXr5IhnSSa4maXgRvQ"
+ destination = "disco-beta.pth.tar"
+ download_file_from_google_drive(id, destination)
+ os.rename("disco-beta.pth.tar", "./checkpoints/disco-beta.pth.tar")
+ 
+ ## step 1: set up model
+ device = "cuda"
+ checkpt_path = "./checkpoints/disco-beta.pth.tar"
+ assert os.path.exists(checkpt_path), "No checkpoint found!"
+ colorizer, colorLabeler = setup_model(checkpt_path, device=device)
+ 
+ def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
+     if hint_img is None:
+         hint_img = rgb_img
+     output = colorize_grayscale(colorizer, colorLabeler, rgb_img, hint_img, n_anchors, is_high_res, is_editable, device)
+     return output
+ 
+ def click_predanchors(rgb_img, n_anchors, is_high_res, is_editable):
+     output = predict_anchors(colorizer, colorLabeler, rgb_img, n_anchors, is_high_res, is_editable, device)
+     return output
+ 
+ ## step 2: configure interface
+ def switch_states(is_checked):
+     if is_checked:
+         return gr.Image.update(visible=True), gr.Button.update(visible=True)
+     else:
+         return gr.Image.update(visible=False), gr.Button.update(visible=False)
+ 
+ demo = gr.Blocks(title="DISCO: Image Colorization")
+ with demo:
+     gr.Markdown(value="""**DISCO: image colorization that disentangles color multimodality and spatial affinity via global anchors**.""")
+     with gr.Row():
+         with gr.Column(scale=1):
+             Image_input = gr.Image(type="numpy", label="Input", interactive=True)
+             Image_anchor = gr.Image(type="numpy", label="Anchor", tool="color-sketch", interactive=True, visible=False)
+             with gr.Row():
+                 Num_anchor = gr.Number(type="int", value=8, label="Num. of anchors (3~14)")
+                 Radio_resolution = gr.Radio(type="index", choices=["Low (256x256)", "High (512x512)"], \
+                                             label="Colorization resolution", value="Low (256x256)")
+             Ckeckbox_editable = gr.Checkbox(default=False, label='Show editable anchors')
+             with gr.Row():
+                 Button_show_anchor = gr.Button(value="Predict anchors", visible=False)
+                 Button_run = gr.Button(value="Colorize")
+         with gr.Column(scale=1):
+             Image_output = gr.Image(type="numpy", label="Output", shape=[100,100])
+ 
+     Ckeckbox_editable.change(fn=switch_states, inputs=Ckeckbox_editable, outputs=[Image_anchor, Button_show_anchor])
+     Button_show_anchor.click(fn=click_predanchors, inputs=[Image_input, Num_anchor, Radio_resolution, Ckeckbox_editable], outputs=Image_anchor)
+     Button_run.click(fn=click_colorize, inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Ckeckbox_editable], \
+                      outputs=Image_output)
+     ## guideline
+     gr.Markdown(value="""
+     **Guideline**
+     1. Upload your image;
+     2. Set up the arguments: "Num. of anchors" and "Colorization resolution";
+     3. Two modes are supported:
+         - **Editable**: check "Show editable anchors" and click "Predict anchors". Then, modify the colors of the predicted anchors (the anchor mask will be applied afterward). Finally, click "Colorize" to get the result.
+         - **Automatic**: click "Colorize" to get the automatically colorized output.
+ 
+     *To know more about the method, please refer to our project page: [https://menghanxia.github.io/projects/disco.html](https://menghanxia.github.io/projects/disco.html)*
+     """)
+ 
+ demo.launch(server_name='9.134.253.83', server_port=7788)
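For reference, the two entry points wired to the buttons above can also be driven without the Gradio UI. Below is a minimal sketch under stated assumptions: the checkpoint already sits at `./checkpoints/disco-beta.pth.tar`, a CUDA device is available, and `photo.jpg` is a hypothetical input file.

```python
# Minimal sketch: run DISCO inference without the Gradio UI (assumptions:
# checkpoint in place, CUDA available, "photo.jpg" is a hypothetical file).
import numpy as np
from PIL import Image
from inference import setup_model, colorize_grayscale

colorizer, colorLabeler = setup_model("./checkpoints/disco-beta.pth.tar", device="cuda")
rgb_img = np.array(Image.open("photo.jpg").convert("RGB"))

# Automatic mode: the input doubles as its own (unused) hint image,
# mirroring what click_colorize does when no anchor image is drawn.
result = colorize_grayscale(colorizer, colorLabeler, rgb_img, rgb_img,
                            n_anchors=8, is_high_res=False, is_editable=False,
                            device="cuda")
Image.fromarray(result).save("colorized.png")
```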
checkpoints/disco_download.sh ADDED
@@ -0,0 +1 @@
+ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1J4vB6kG4xBLUUKpXr5IhnSSa4maXgRvQ' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1J4vB6kG4xBLUUKpXr5IhnSSa4maXgRvQ" -O disco-beta.pth.tar && rm -rf /tmp/cookies.txt
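The one-liner above performs Google Drive's confirm-token handshake for large files: the first `wget` saves the session cookie and scrapes the `confirm=` token with `sed`, then the second `wget` downloads the checkpoint itself. A simpler sketch, assuming the third-party `gdown` package is acceptable (not a dependency of this repo; recent versions expose an `id` argument):

```python
# Hypothetical alternative to the wget one-liner, using the gdown package
# (pip install gdown). Not part of this commit.
import gdown

gdown.download(id="1J4vB6kG4xBLUUKpXr5IhnSSa4maXgRvQ",
               output="disco-beta.pth.tar", quiet=False)
```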
cog.yaml ADDED
@@ -0,0 +1,41 @@
+ # Configuration for Cog ⚙️
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+ 
+ build:
+   # set to true if your model requires a GPU
+   cuda: "10.2"
+   gpu: true
+ 
+   # a list of ubuntu apt packages to install
+   system_packages:
+     # - "libgl1-mesa-glx"
+     # - "libglib2.0-0"
+     - "libgl1-mesa-dev"
+ 
+   # python version in the form '3.8' or '3.8.12'
+   python_version: "3.8"
+ 
+   # a list of packages in the format <package-name>==<version>
+   python_packages:
+     # - "numpy==1.19.4"
+     # - "torch==1.8.0"
+     # - "torchvision==0.9.0"
+     - "numpy==1.23.1"
+     - "torch==1.8.0"
+     - "torchvision==0.9.0"
+     - "opencv-python==4.6.0.66"
+     - "pandas==1.4.3"
+     - "pillow==9.2.0"
+     - "tqdm==4.64.0"
+     - "scikit-image==0.19.3"
+     - "scikit-learn==1.1.2"
+     - "scipy==1.9.1"
+ 
+   # commands run after the environment is setup
+   # run:
+   #   - "echo env is ready!"
+   #   - "echo another command if needed"
+ 
+ # predict.py defines how predictions are run on your model
+ predict: "predict.py:Predictor"
+ # image: "r8.im/menghanxia/disco"
environment.yml ADDED
@@ -0,0 +1,121 @@
+ name: DISCO
+ channels:
+   - pytorch
+   - defaults
+   - conda-forge
+ dependencies:
+   - blas=1.0=mkl
+   - bzip2=1.0.8=h7b6447c_0
+   - ca-certificates=2022.07.19=h06a4308_0
+   - certifi=2022.6.15=py38h06a4308_0
+   - cudatoolkit=10.2.89=hfd86e86_1
+   - freetype=2.11.0=h70c0345_0
+   - giflib=5.2.1=h7b6447c_0
+   - gmp=6.2.1=h295c915_3
+   - gnutls=3.6.15=he1e5248_0
+   - intel-openmp=2021.4.0=h06a4308_3561
+   - jpeg=9b=h024ee3a_2
+   - lame=3.100=h7b6447c_0
+   - lcms2=2.12=h3be6417_0
+   - ld_impl_linux-64=2.38=h1181459_1
+   - libffi=3.3=he6710b0_2
+   - libgcc-ng=11.2.0=h1234567_1
+   - libiconv=1.16=h7f8727e_2
+   - libidn2=2.3.2=h7f8727e_0
+   - libpng=1.6.37=hbc83047_0
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - libtasn1=4.16.0=h27cfd23_0
+   - libtiff=4.1.0=h2733197_1
+   - libunistring=0.9.10=h27cfd23_0
+   - libuv=1.40.0=h7b6447c_0
+   - libwebp=1.2.0=h89dd481_0
+   - lz4-c=1.9.3=h295c915_1
+   - mkl=2021.4.0=h06a4308_640
+   - mkl-service=2.4.0=py38h7f8727e_0
+   - mkl_fft=1.3.1=py38hd3c417c_0
+   - mkl_random=1.2.2=py38h51133e4_0
+   - ncurses=6.3=h5eee18b_3
+   - nettle=3.7.3=hbbd107a_1
+   - ninja=1.10.2=h06a4308_5
+   - ninja-base=1.10.2=hd09550d_5
+   - numpy=1.23.1=py38h6c91a56_0
+   - numpy-base=1.23.1=py38ha15fc14_0
+   - openh264=2.1.1=h4ff587b_0
+   - openssl=1.1.1q=h7f8727e_0
+   - pillow=9.2.0=py38hace64e9_1
+   - pip=22.1.2=py38h06a4308_0
+   - python=3.8.13=h12debd9_0
+   - readline=8.1.2=h7f8727e_1
+   - setuptools=63.4.1=py38h06a4308_0
+   - six=1.16.0=pyhd3eb1b0_1
+   - sqlite=3.39.2=h5082296_0
+   - tk=8.6.12=h1ccaba5_0
+   - typing_extensions=4.3.0=py38h06a4308_0
+   - wheel=0.37.1=pyhd3eb1b0_0
+   - xz=5.2.5=h7f8727e_1
+   - zlib=1.2.12=h7f8727e_2
+   - zstd=1.4.9=haebb681_0
+   - ffmpeg=4.3=hf484d3e_0
+   - pytorch=1.8.0=py3.8_cuda10.2_cudnn7.6.5_0
+   - torchaudio=0.8.0=py38
+   - torchvision=0.9.0=py38_cu102
+   - pip:
+     - addict==2.4.0
+     - astunparse==1.6.3
+     - cachetools==4.2.4
+     - charset-normalizer==2.0.7
+     - clang==5.0
+     - cycler==0.11.0
+     - flatbuffers==1.12
+     - fonttools==4.37.1
+     - future==0.18.2
+     - gast==0.4.0
+     - google-auth==2.3.2
+     - google-auth-oauthlib==0.4.6
+     - google-pasta==0.2.0
+     - grpcio==1.41.1
+     - h5py==3.1.0
+     - idna==3.3
+     - imageio==2.21.1
+     - joblib==1.1.0
+     - keras==2.6.0
+     - keras-preprocessing==1.1.2
+     - kiwisolver==1.4.4
+     - lpips==0.1.4
+     - markdown==3.3.4
+     - matplotlib==3.5.3
+     - networkx==2.8.6
+     - oauthlib==3.1.1
+     - opencv-python==4.6.0.66
+     - opt-einsum==3.3.0
+     - packaging==21.3
+     - pandas==1.4.3
+     - protobuf==3.19.0
+     - pyasn1==0.4.8
+     - pyasn1-modules==0.2.8
+     - pyparsing==3.0.9
+     - python-dateutil==2.8.2
+     - pytz==2022.2.1
+     - pywavelets==1.3.0
+     - pyyaml==6.0
+     - requests==2.26.0
+     - requests-oauthlib==1.3.0
+     - rsa==4.7.2
+     - scikit-image==0.19.3
+     - scikit-learn==1.1.2
+     - scipy==1.9.1
+     - tensorboard-data-server==0.6.1
+     - tensorboard-plugin-wit==1.8.0
+     - tensorflow-estimator==2.6.0
+     - tensorflow-gpu==2.6.0
+     - termcolor==1.1.0
+     - threadpoolctl==3.1.0
+     - tifffile==2022.8.12
+     - torch==1.8.0
+     - tqdm==4.64.0
+     - urllib3==1.26.7
+     - werkzeug==2.0.2
+     - wrapt==1.12.1
+     - yapf==0.32.0
+ prefix: /root/data/programs/anaconda3/envs/DISCO
+
inference.py ADDED
@@ -0,0 +1,109 @@
+ import os, glob, sys, logging
+ import argparse, datetime, time
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models import model, basic
+ from utils import util
+ 
+ 
+ def setup_model(checkpt_path, device="cuda"):
+     """Load the model into memory to make running multiple predictions efficient"""
+     seed = 130
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     #print('--------------', torch.cuda.is_available())
+     colorLabeler = basic.ColorLabel(device=device)
+     colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True, colorLabeler=colorLabeler)
+     colorizer = colorizer.to(device)
+     #checkpt_path = "./checkpoints/disco-beta.pth.rar"
+     assert os.path.exists(checkpt_path)
+     data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
+     colorizer.load_state_dict(data_dict['state_dict'])
+     colorizer.eval()
+     return colorizer, colorLabeler
+ 
+ 
+ def resize_ab2l(gray_img, lab_imgs, vis=False):
+     H, W = gray_img.shape[:2]
+     resized_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
+     if vis:
+         gray_img = cv2.resize(lab_imgs[:,:,:1], (W,H), interpolation=cv2.INTER_LINEAR)
+         return np.concatenate((gray_img[:,:,np.newaxis], resized_ab), axis=2)
+     else:
+         return np.concatenate((gray_img, resized_ab), axis=2)
+ 
+ def prepare_data(rgb_img, target_res):
+     rgb_img = np.array(rgb_img / 255., np.float32)
+     lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
+     org_grays = (lab_img[:,:,[0]]-50.) / 50.
+     lab_img = cv2.resize(lab_img, target_res, interpolation=cv2.INTER_LINEAR)
+ 
+     lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
+     gray_img = (lab_img[0:1,:,:]-50.) / 50.
+     ab_chans = lab_img[1:3,:,:] / 110.
+     input_grays = gray_img.unsqueeze(0)
+     input_colors = ab_chans.unsqueeze(0)
+     return input_grays, input_colors, org_grays
+ 
+ 
+ def colorize_grayscale(colorizer, color_class, rgb_img, hint_img, n_anchors, is_high_res, is_editable, device="cuda"):
+     n_anchors = int(n_anchors)
+     n_anchors = max(n_anchors, 3)
+     n_anchors = min(n_anchors, 14)
+     target_res = (512,512) if is_high_res else (256,256)
+     input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
+     input_grays = input_grays.to(device)
+     input_colors = input_colors.to(device)
+ 
+     if is_editable:
+         print('>>>:editable mode')
+         sampled_T = -1
+         _, input_colors, _ = prepare_data(hint_img, target_res)
+         input_colors = input_colors.to(device)
+         pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
+                                 input_colors, n_anchors, sampled_T)
+     else:
+         print('>>>:automatic mode')
+         sampled_T = 0
+         pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
+                                 input_colors, n_anchors, sampled_T)
+ 
+     pred_labs = torch.cat((input_grays,enhanced_ab), dim=1)
+     lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
+     lab_imgs = resize_ab2l(org_grays, lab_imgs)
+ 
+     lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
+     lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
+     rgb_output = cv2.cvtColor(lab_imgs[:,:,:], cv2.COLOR_LAB2RGB)
+     return (rgb_output*255.0).astype(np.uint8)
+ 
+ 
+ def predict_anchors(colorizer, color_class, rgb_img, n_anchors, is_high_res, is_editable, device="cuda"):
+     n_anchors = int(n_anchors)
+     n_anchors = max(n_anchors, 3)
+     n_anchors = min(n_anchors, 14)
+     target_res = (512,512) if is_high_res else (256,256)
+     input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
+     input_grays = input_grays.to(device)
+     input_colors = input_colors.to(device)
+ 
+     sampled_T, sp_size = 0, 16
+     pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
+                             input_colors, n_anchors, sampled_T)
+     pred_probs = pal_logit
+     guided_colors = color_class.decode_ind2ab(ref_logit, T=0)
+     guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
+     anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
+     marked_labs = basic.mark_color_hints(input_grays, guided_colors, anchor_masks, base_ABs=None)
+     lab_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
+     lab_imgs = resize_ab2l(org_grays, lab_imgs, vis=True)
+ 
+     lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
+     lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
+     rgb_output = cv2.cvtColor(lab_imgs[:,:,:], cv2.COLOR_LAB2RGB)
+     return (rgb_output*255.0).astype(np.uint8)
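Both functions above share one Lab convention: `prepare_data` maps L from [0,100] to [-1,1] via (L-50)/50 and the ab channels to roughly [-1,1] via division by 110, and the tail of each function inverts that mapping before `cv2.COLOR_LAB2RGB`. A standalone sanity check of the round trip (illustrative mid-gray input, not from the repo):

```python
# Sanity-check sketch of the Lab normalization used by prepare_data and
# inverted in colorize_grayscale/predict_anchors (mid-gray chosen arbitrarily).
import numpy as np
import cv2

rgb = np.full((2, 2, 3), 0.5, np.float32)       # float RGB in [0,1]
lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)      # L in [0,100], ab in ~[-110,110]
gray_norm = (lab[:, :, 0] - 50.) / 50.          # -> [-1,1], as in prepare_data
ab_norm = lab[:, :, 1:] / 110.                  # -> ~[-1,1]
# The inverse applied before converting back to RGB:
assert np.allclose(gray_norm * 50. + 50., lab[:, :, 0], atol=1e-4)
assert np.allclose(ab_norm * 110., lab[:, :, 1:], atol=1e-4)
```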
models/__init__.py ADDED
File without changes
models/anchor_gen.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.autograd import Function
+ from models import basic, clusterkit
+ import pdb
+ 
+ class AnchorAnalysis:
+     def __init__(self, mode, colorLabeler):
+         ## anchor generating mode: 1.random; 2.clustering
+         self.mode = mode
+         self.colorLabeler = colorLabeler
+ 
+     def _detect_correlation(self, data_tensors, color_probs, hint_masks, thres=0.1):
+         N,C,H,W = data_tensors.shape
+         ## (N,C,HW)
+         data_vecs = data_tensors.flatten(2)
+         prob_vecs = color_probs.flatten(2)
+         mask_vecs = hint_masks.flatten(2)
+         #anchor_data = torch.masked_select(data_vecs, mask_vecs.bool()).view(N,C,-1)
+         #anchor_prob = torch.masked_select(prob_vecs, mask_vecs.bool()).view(N,313,-1)
+         #_,_,K = anchor_data.shape
+         anchor_mask = torch.matmul(mask_vecs.permute(0,2,1), mask_vecs)
+         cosine_sim = True
+         ## non-similarity matrix
+         if cosine_sim:
+             norm_data = F.normalize(data_vecs, p=2, dim=1)
+             ## (N,HW,HW) = (N,HW,C) X (N,C,HW)
+             corr_matrix = torch.matmul(norm_data.permute(0,2,1), norm_data)
+             ## remapping: [-1.0,1.0] to [0.0,1.0], and convert into dis-similarity
+             dist_matrix = 1.0 - 0.5*(corr_matrix + 1.0)
+         else:
+             ## (N,HW,HW) = (N,HW,C) X (N,C,HW)
+             XtX = torch.matmul(data_vecs.permute(0,2,1), data_vecs)
+             diag_vec = torch.diagonal(XtX, dim1=-2, dim2=-1)
+             A = diag_vec.unsqueeze(1).repeat(1,H*W,1)
+             At = diag_vec.unsqueeze(2).repeat(1,1,H*W)
+             dist_matrix = A - 2*XtX + At
+             #dist_matrix = dist_matrix + 1e7*torch.eye(K).to(data_tensors.device).repeat(N,1,1)
+         ## for debug use
+         K = 8
+         anchor_adj_matrix = torch.masked_select(dist_matrix, anchor_mask.bool()).view(N,K,K)
+         ## detect connected nodes
+         adj_matrix = torch.where((dist_matrix < thres) & (anchor_mask > 0), torch.ones_like(dist_matrix), torch.zeros_like(dist_matrix))
+         adj_matrix = torch.matmul(adj_matrix, adj_matrix)
+         adj_matrix = adj_matrix / (1e-7+adj_matrix)
+         ## merge nodes
+         ## (N,K,C) = (N,K,K) X (N,K,C)
+         anchor_prob = torch.matmul(adj_matrix, prob_vecs.permute(0,2,1)) / torch.sum(adj_matrix, dim=2, keepdim=True)
+         updated_prob_vecs = anchor_prob.permute(0,2,1) * mask_vecs + (1-mask_vecs) * prob_vecs
+         color_probs = updated_prob_vecs.view(N,313,H,W)
+         return color_probs, anchor_adj_matrix
+ 
+     def _sample_anchor_colors(self, pred_prob, hint_mask, T=0):
+         N,C,H,W = pred_prob.shape
+         topk = 10
+         assert T < topk
+         sorted_probs, batch_indexs = torch.sort(pred_prob, dim=1, descending=True)
+         ## (N,topk,H,W,1)
+         topk_probs = torch.softmax(sorted_probs[:,:topk,:,:], dim=1).unsqueeze(4)
+         topk_indexs = batch_indexs[:,:topk,:,:]
+         topk_ABs = torch.stack([self.colorLabeler.q_to_ab.index_select(0, q_i.flatten()).reshape(topk,H,W,2)
+                                 for q_i in topk_indexs])
+         ## (N,topk,H,W,2)
+         topk_ABs = topk_ABs / 110.0
+         ## choose the most distinctive 3 colors for each anchor
+         if T == 0:
+             sampled_ABs = topk_ABs[:,0,:,:,:]
+         elif T == 1:
+             sampled_AB0 = topk_ABs[:,[0],:,:,:]
+             internal_diff = torch.norm(topk_ABs-sampled_AB0, p=2, dim=4, keepdim=True)
+             _, batch_indexs = torch.sort(internal_diff, dim=1, descending=True)
+             ## (N,1,H,W,2)
+             selected_index = batch_indexs[:,[0],:,:,:].expand([-1,-1,-1,-1,2])
+             sampled_ABs = torch.gather(topk_ABs, 1, selected_index)
+             sampled_ABs = sampled_ABs.squeeze(1)
+         else:
+             sampled_AB0 = topk_ABs[:,[0],:,:,:]
+             internal_diff = torch.norm(topk_ABs-sampled_AB0, p=2, dim=4, keepdim=True)
+             _, batch_indexs = torch.sort(internal_diff, dim=1, descending=True)
+             selected_index = batch_indexs[:,[0],:,:,:].expand([-1,-1,-1,-1,2])
+             sampled_AB1 = torch.gather(topk_ABs, 1, selected_index)
+             internal_diff2 = torch.norm(topk_ABs-sampled_AB1, p=2, dim=4, keepdim=True)
+             _, batch_indexs = torch.sort(internal_diff+internal_diff2, dim=1, descending=True)
+             ## (N,1,H,W,2)
+             selected_index = batch_indexs[:,[T-2],:,:,:].expand([-1,-1,-1,-1,2])
+             sampled_ABs = torch.gather(topk_ABs, 1, selected_index)
+             sampled_ABs = sampled_ABs.squeeze(1)
+ 
+         return sampled_ABs.permute(0,3,1,2)
+ 
+     def __call__(self, data_tensors, n_anchors, spixel_sizes, use_sklearn_kmeans=False):
+         N,C,H,W = data_tensors.shape
+         if self.mode == 'clustering':
+             ## clusters map: (N,K,H,W)
+             cluster_mask = clusterkit.batch_kmeans_pytorch(data_tensors, n_anchors, 'euclidean', use_sklearn_kmeans)
+             #noises = torch.rand(N,1,H,W).to(cluster_mask.device)
+             perturb_factors = spixel_sizes
+             cluster_prob = cluster_mask + perturb_factors * 0.01
+             hint_mask_layers = F.one_hot(torch.argmax(cluster_prob.flatten(2), dim=-1), num_classes=H*W).float()
+             hint_mask = torch.sum(hint_mask_layers, dim=1, keepdim=True).view(N,1,H,W)
+         else:
+             #print('----------hello, random!')
+             cluster_mask = torch.zeros(N,n_anchors,H,W).to(data_tensors.device)
+             binary_mask = basic.get_random_mask(N, H, W, minNum=n_anchors, maxNum=n_anchors)
+             hint_mask = torch.from_numpy(binary_mask).to(data_tensors.device)
+         return hint_mask, cluster_mask
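In the clustering branch of `__call__` above, each of the `n_anchors` cluster channels contributes exactly one anchor pixel: the spatial argmax of its (slightly perturbed) soft assignment map. A standalone sketch of that trick on random data:

```python
# Sketch of the hint-mask construction from AnchorAnalysis.__call__:
# one anchor pixel per cluster channel, picked by spatial argmax (random data).
import torch
import torch.nn.functional as F

N, K, H, W = 1, 4, 8, 8
cluster_prob = torch.rand(N, K, H, W)
# For each of the K channels, one-hot encode its strongest spatial location ...
hint_layers = F.one_hot(torch.argmax(cluster_prob.flatten(2), dim=-1),
                        num_classes=H * W).float()
# ... then collapse the K one-hot layers into a single (N,1,H,W) mask.
hint_mask = torch.sum(hint_layers, dim=1, keepdim=True).view(N, 1, H, W)
assert int((hint_mask > 0).sum()) <= K          # at most one anchor per cluster
```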
models/basic.py ADDED
@@ -0,0 +1,504 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.utils.spectral_norm as spectral_norm
+ from torch.autograd import Function
+ from utils import util, cielab
+ import cv2, math, random
+ 
+ def tensor2array(tensors):
+     arrays = tensors.detach().to("cpu").numpy()
+     return np.transpose(arrays, (0, 2, 3, 1))
+ 
+ 
+ def rgb2gray(color_batch):
+     #! gray = 0.299*R+0.587*G+0.114*B
+     gray_batch = color_batch[:, 0, ...] * 0.299 + color_batch[:, 1, ...] * 0.587 + color_batch[:, 2, ...] * 0.114
+     gray_batch = gray_batch.unsqueeze_(1)
+     return gray_batch
+ 
+ 
+ def getParamsAmount(model):
+     params = list(model.parameters())
+     count = 0
+     for var in params:
+         l = 1
+         for j in var.size():
+             l *= j
+         count += l
+     return count
+ 
+ 
+ def checkAverageGradient(model):
+     meanGrad, cnt = 0.0, 0
+     for name, parms in model.named_parameters():
+         if parms.requires_grad:
+             meanGrad += torch.mean(torch.abs(parms.grad))
+             cnt += 1
+     return meanGrad.item() / cnt
+ 
+ 
+ def get_random_mask(N, H, W, minNum, maxNum):
+     binary_maps = np.zeros((N, H*W), np.float32)
+     for i in range(N):
+         locs = random.sample(range(0, H*W), random.randint(minNum,maxNum))
+         binary_maps[i, locs] = 1
+     return binary_maps.reshape(N,1,H,W)
+ 
+ 
+ def io_user_control(hint_mask, spix_colors, output=True):
+     cache_dir = '/apdcephfs/private_richardxia'
+     if output:
+         print('--- data saving')
+         mask_imgs = tensor2array(hint_mask) * 2.0 - 1.0
+         util.save_images_from_batch(mask_imgs, cache_dir, ['mask.png'], -1)
+         fake_gray = torch.zeros_like(spix_colors[:,[0],:,:])
+         spix_labs = torch.cat((fake_gray,spix_colors), dim=1)
+         spix_imgs = tensor2array(spix_labs)
+         util.save_normLabs_from_batch(spix_imgs, cache_dir, ['color.png'], -1)
+         return hint_mask, spix_colors
+     else:
+         print('--- data loading')
+         mask_img = cv2.imread(cache_dir+'/mask.png', cv2.IMREAD_GRAYSCALE)
+         mask_img = np.expand_dims(mask_img, axis=2) / 255.
+         hint_mask = torch.from_numpy(mask_img.transpose((2, 0, 1)))
+         hint_mask = hint_mask.unsqueeze(0).cuda()
+         bgr_img = cv2.imread(cache_dir+'/color.png', cv2.IMREAD_COLOR)
+         rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
+         rgb_img = np.array(rgb_img / 255., np.float32)
+         lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
+         lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
+         ab_chans = lab_img[1:3,:,:] / 110.
+         spix_colors = ab_chans.unsqueeze(0).cuda()
+         return hint_mask.float(), spix_colors.float()
+ 
+ 
+ class Quantize(Function):
+     @staticmethod
+     def forward(ctx, x):
+         ctx.save_for_backward(x)
+         y = x.round()
+         return y
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         """
+         In the backward pass we receive a Tensor containing the gradient of the loss
+         with respect to the output, and we need to compute the gradient of the loss
+         with respect to the input.
+         """
+         inputX = ctx.saved_tensors
+         return grad_output
+ 
+ 
+ def mark_color_hints(input_grays, target_ABs, gate_maps, kernel_size=3, base_ABs=None):
+     ## to highlight the seeds with 1-pixel margin
+     binary_map = torch.where(gate_maps>0.7, torch.ones_like(gate_maps), torch.zeros_like(gate_maps))
+     center_mask = dilate_seeds(binary_map, kernel_size=kernel_size)
+     margin_mask = dilate_seeds(binary_map, kernel_size=kernel_size+2) - center_mask
+     ## drop colors
+     dilated_seeds = dilate_seeds(gate_maps, kernel_size=kernel_size+2)
+     marked_grays = torch.where(margin_mask > 1e-5, torch.ones_like(gate_maps), input_grays)
+     if base_ABs is None:
+         marked_ABs = torch.where(center_mask < 1e-5, torch.zeros_like(target_ABs), target_ABs)
+     else:
+         marked_ABs = torch.where(margin_mask > 1e-5, torch.zeros_like(base_ABs), base_ABs)
+         marked_ABs = torch.where(center_mask > 1e-5, target_ABs, marked_ABs)
+     return torch.cat((marked_grays,marked_ABs), dim=1)
+ 
+ def dilate_seeds(gate_maps, kernel_size=3):
+     N,C,H,W = gate_maps.shape
+     input_unf = F.unfold(gate_maps, kernel_size, padding=kernel_size//2)
+     #! Notice: differentiable? just like max pooling?
+     dilated_seeds, _ = torch.max(input_unf, dim=1, keepdim=True)
+     output = F.fold(dilated_seeds, output_size=(H,W), kernel_size=1)
+     #print('-------', input_unf.shape)
+     return output
+ 
+ 
+ class RebalanceLoss(Function):
+     @staticmethod
+     def forward(ctx, data_input, weights):
+         ctx.save_for_backward(weights)
+         return data_input.clone()
+ 
+     @staticmethod
+     def backward(ctx, grad_output):
+         weights, = ctx.saved_tensors
+         # reweigh gradient pixelwise so that rare colors get a chance to
+         # contribute
+         grad_input = grad_output * weights
+         # second return value is None since we are not interested in the
+         # gradient with respect to the weights
+         return grad_input, None
+ 
+ 
+ class GetClassWeights:
+     def __init__(self, cielab, lambda_=0.5, device='cuda'):
+         prior = torch.from_numpy(cielab.gamut.prior).cuda()
+         uniform = torch.zeros_like(prior)
+         uniform[prior > 0] = 1 / (prior > 0).sum().type_as(uniform)
+         self.weights = 1 / ((1 - lambda_) * prior + lambda_ * uniform)
+         self.weights /= torch.sum(prior * self.weights)
+ 
+     def __call__(self, ab_actual):
+         return self.weights[ab_actual.argmax(dim=1, keepdim=True)]
+ 
+ 
+ class ColorLabel:
+     def __init__(self, lambda_=0.5, device='cuda'):
+         self.cielab = cielab.CIELAB()
+         self.q_to_ab = torch.from_numpy(self.cielab.q_to_ab).to(device)
+         prior = torch.from_numpy(self.cielab.gamut.prior).to(device)
+         uniform = torch.zeros_like(prior)
+         uniform[prior>0] = 1 / (prior>0).sum().type_as(uniform)
+         self.weights = 1 / ((1-lambda_) * prior + lambda_ * uniform)
+         self.weights /= torch.sum(prior * self.weights)
+ 
+     def visualize_label(self, step=3):
+         height, width = 200, 313*step
+         label_lab = np.ones((height,width,3), np.float32)
+         for x in range(313):
+             ab = self.cielab.q_to_ab[x,:]
+             label_lab[:,step*x:step*(x+1),1:] = ab / 110.
+         label_lab[:,:,0] = np.zeros((height,width), np.float32)
+         return label_lab
+ 
+     @staticmethod
+     def _gauss_eval(x, mu, sigma):
+         norm = 1 / (2 * math.pi * sigma)
+         return norm * torch.exp(-torch.sum((x - mu)**2, dim=0) / (2 * sigma**2))
+ 
+     def get_classweights(self, batch_gt_indx):
+         #return self.weights[batch_gt_q.argmax(dim=1, keepdim=True)]
+         return self.weights[batch_gt_indx]
+ 
+     def encode_ab2ind(self, batch_ab, neighbours=5, sigma=5.0):
+         batch_ab = batch_ab * 110.
+         n, _, h, w = batch_ab.shape
+         m = n * h * w
+         # find nearest neighbours
+         ab_ = batch_ab.permute(1, 0, 2, 3).reshape(2, -1)  # (2, n*h*w)
+         cdist = torch.cdist(self.q_to_ab, ab_.t())
+         nns = cdist.argsort(dim=0)[:neighbours, :]
+         # gaussian weighting
+         nn_gauss = batch_ab.new_zeros(neighbours, m)
+         for i in range(neighbours):
+             nn_gauss[i, :] = self._gauss_eval(self.q_to_ab[nns[i, :], :].t(), ab_, sigma)
+         nn_gauss /= nn_gauss.sum(dim=0, keepdim=True)
+         # expand
+         bins = self.cielab.gamut.EXPECTED_SIZE
+         q = batch_ab.new_zeros(bins, m)
+         q[nns, torch.arange(m).repeat(neighbours, 1)] = nn_gauss
+         return q.reshape(bins, n, h, w).permute(1, 0, 2, 3)
+ 
+     def decode_ind2ab(self, batch_q, T=0.38):
+         _, _, h, w = batch_q.shape
+         batch_q = F.softmax(batch_q, dim=1)
+         if T%1 == 0:
+             # take the T-st probable index
+             sorted_probs, batch_indexs = torch.sort(batch_q, dim=1, descending=True)
+             #print('checking [index]', batch_indexs[:,0:5,5,5])
+             #print('checking [probs]', sorted_probs[:,0:5,5,5])
+             batch_indexs = batch_indexs[:,T:T+1,:,:]
+             #batch_indexs = torch.where(sorted_probs[:,T:T+1,:,:] > 0.25, batch_indexs[:,T:T+1,:,:], batch_indexs[:,0:1,:,:])
+             ab = torch.stack([
+                 self.q_to_ab.index_select(0, q_i.flatten()).reshape(h,w,2).permute(2,0,1)
+                 for q_i in batch_indexs])
+         else:
+             batch_q = torch.exp(batch_q / T)
+             batch_q /= batch_q.sum(dim=1, keepdim=True)
+             a = torch.tensordot(batch_q, self.q_to_ab[:,0], dims=((1,), (0,)))
+             a = a.unsqueeze(dim=1)
+             b = torch.tensordot(batch_q, self.q_to_ab[:,1], dims=((1,), (0,)))
+             b = b.unsqueeze(dim=1)
+             ab = torch.cat((a, b), dim=1)
+         ab = ab / 110.
+         return ab.type(batch_q.dtype)
+ 
+ 
+ def init_spixel_grid(img_height, img_width, spixel_size=16):
+     # get spixel id for the final assignment
+     n_spixl_h = int(np.floor(img_height/spixel_size))
+     n_spixl_w = int(np.floor(img_width/spixel_size))
+     spixel_height = int(img_height / (1. * n_spixl_h))
+     spixel_width = int(img_width / (1. * n_spixl_w))
+     spix_values = np.int32(np.arange(0, n_spixl_w * n_spixl_h).reshape((n_spixl_h, n_spixl_w)))
+ 
+     def shift9pos(input, h_shift_unit=1, w_shift_unit=1):
+         # input should be padding as (c, 1+ height+1, 1+width+1)
+         input_pd = np.pad(input, ((h_shift_unit, h_shift_unit), (w_shift_unit, w_shift_unit)), mode='edge')
+         input_pd = np.expand_dims(input_pd, axis=0)
+         # assign to ...
+         top = input_pd[:, :-2 * h_shift_unit, w_shift_unit:-w_shift_unit]
+         bottom = input_pd[:, 2 * h_shift_unit:, w_shift_unit:-w_shift_unit]
+         left = input_pd[:, h_shift_unit:-h_shift_unit, :-2 * w_shift_unit]
+         right = input_pd[:, h_shift_unit:-h_shift_unit, 2 * w_shift_unit:]
+         center = input_pd[:,h_shift_unit:-h_shift_unit,w_shift_unit:-w_shift_unit]
+         bottom_right = input_pd[:, 2 * h_shift_unit:, 2 * w_shift_unit:]
+         bottom_left = input_pd[:, 2 * h_shift_unit:, :-2 * w_shift_unit]
+         top_right = input_pd[:, :-2 * h_shift_unit, 2 * w_shift_unit:]
+         top_left = input_pd[:, :-2 * h_shift_unit, :-2 * w_shift_unit]
+         shift_tensor = np.concatenate([top_left, top, top_right,
+                                        left, center, right,
+                                        bottom_left, bottom, bottom_right], axis=0)
+         return shift_tensor
+ 
+     spix_idx_tensor_ = shift9pos(spix_values)
+     spix_idx_tensor = np.repeat(
+         np.repeat(spix_idx_tensor_, spixel_height, axis=1), spixel_width, axis=2)
+     spixel_id_tensor = torch.from_numpy(spix_idx_tensor).type(torch.float)
+ 
+     #! pixel coord feature maps
+     all_h_coords = np.arange(0, img_height, 1)
+     all_w_coords = np.arange(0, img_width, 1)
+     curr_pxl_coord = np.array(np.meshgrid(all_h_coords, all_w_coords, indexing='ij'))
+     coord_feat_tensor = np.concatenate([curr_pxl_coord[1:2, :, :], curr_pxl_coord[:1, :, :]])
+     coord_feat_tensor = torch.from_numpy(coord_feat_tensor).type(torch.float)
+ 
+     return spixel_id_tensor, coord_feat_tensor
+ 
+ 
+ def split_spixels(assign_map, spixel_ids):
+     N,C,H,W = assign_map.shape
+     spixel_id_map = spixel_ids.expand(N,-1,-1,-1)
+     assig_max,_ = torch.max(assign_map, dim=1, keepdim=True)
+     assignment_ = torch.where(assign_map == assig_max, torch.ones(assign_map.shape).cuda(), torch.zeros(assign_map.shape).cuda())
+     ## winner take all
+     new_spixl_map_ = spixel_id_map * assignment_
+     new_spixl_map = torch.sum(new_spixl_map_,dim=1,keepdim=True).type(torch.int)
+     return new_spixl_map
+ 
+ 
+ def poolfeat(input, prob, sp_h=2, sp_w=2, need_entry_prob=False):
+     def feat_prob_sum(feat_sum, prob_sum, shift_feat):
+         feat_sum += shift_feat[:, :-1, :, :]
+         prob_sum += shift_feat[:, -1:, :, :]
+         return feat_sum, prob_sum
+ 
+     b, _, h, w = input.shape
+     h_shift_unit = 1
+     w_shift_unit = 1
+     p2d = (w_shift_unit, w_shift_unit, h_shift_unit, h_shift_unit)
+     feat_ = torch.cat([input, torch.ones([b, 1, h, w], device=input.device)], dim=1)  # b* (n+1) *h*w
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 0, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))  # b * (n+1) * h* w
+     send_to_top_left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, 2 * w_shift_unit:]
+     feat_sum = send_to_top_left[:, :-1, :, :].clone()
+     prob_sum = send_to_top_left[:, -1:, :, :].clone()
+ 
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 1, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))
+     top = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, w_shift_unit:-w_shift_unit]
+     feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, top)
+ 
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 2, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))
+     top_right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, :-2 * w_shift_unit]
+     feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, top_right)
+ 
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 3, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))
+     left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, 2 * w_shift_unit:]
+     feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, left)
+ 
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 4, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))
+     center = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, w_shift_unit:-w_shift_unit]
+     feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, center)
+ 
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 5, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))
+     right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, :-2 * w_shift_unit]
+     feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, right)
+ 
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 6, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))
+     bottom_left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, 2 * w_shift_unit:]
+     feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom_left)
+ 
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 7, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))
+     bottom = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, w_shift_unit:-w_shift_unit]
+     feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom)
+ 
+     prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 8, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w))
+     bottom_right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, :-2 * w_shift_unit]
+     feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom_right)
+     pooled_feat = feat_sum / (prob_sum + 1e-8)
+     if need_entry_prob:
+         return pooled_feat, prob_sum
+     return pooled_feat
+ 
+ 
+ def get_spixel_size(affinity_map, sp_h=2, sp_w=2, elem_thres=25):
+     N,C,H,W = affinity_map.shape
+     device = affinity_map.device
+     assign_max,_ = torch.max(affinity_map, dim=1, keepdim=True)
+     assign_map = torch.where(affinity_map==assign_max, torch.ones(affinity_map.shape, device=device), torch.zeros(affinity_map.shape, device=device))
+     ## one_map = (N,1,H,W)
+     _, elem_num_maps = poolfeat(torch.ones(assign_max.shape, device=device), assign_map, sp_h, sp_w, True)
+     #all_one_map = torch.ones(elem_num_maps.shape).cuda()
+     #empty_mask = torch.where(elem_num_maps < elem_thres/256, all_one_map, 1-all_one_map)
+     return elem_num_maps
+ 
+ 
+ def upfeat(input, prob, up_h=2, up_w=2):
+     # input b*n*H*W downsampled
+     # prob b*9*h*w
+     b, c, h, w = input.shape
+ 
+     h_shift = 1
+     w_shift = 1
+ 
+     p2d = (w_shift, w_shift, h_shift, h_shift)
+     feat_pd = F.pad(input, p2d, mode='constant', value=0)
+ 
+     gt_frm_top_left = F.interpolate(feat_pd[:, :, :-2 * h_shift, :-2 * w_shift], size=(h * up_h, w * up_w), mode='nearest')
+     feat_sum = gt_frm_top_left * prob.narrow(1,0,1)
+ 
+     top = F.interpolate(feat_pd[:, :, :-2 * h_shift, w_shift:-w_shift], size=(h * up_h, w * up_w), mode='nearest')
+     feat_sum += top * prob.narrow(1, 1, 1)
+ 
+     top_right = F.interpolate(feat_pd[:, :, :-2 * h_shift, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
+     feat_sum += top_right * prob.narrow(1,2,1)
+ 
+     left = F.interpolate(feat_pd[:, :, h_shift:-w_shift, :-2 * w_shift], size=(h * up_h, w * up_w), mode='nearest')
+     feat_sum += left * prob.narrow(1, 3, 1)
+ 
+     center = F.interpolate(input, (h * up_h, w * up_w), mode='nearest')
+     feat_sum += center * prob.narrow(1, 4, 1)
+ 
+     right = F.interpolate(feat_pd[:, :, h_shift:-w_shift, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
+     feat_sum += right * prob.narrow(1, 5, 1)
+ 
+     bottom_left = F.interpolate(feat_pd[:, :, 2 * h_shift:, :-2 * w_shift], size=(h * up_h, w * up_w), mode='nearest')
+     feat_sum += bottom_left * prob.narrow(1, 6, 1)
+ 
+     bottom = F.interpolate(feat_pd[:, :, 2 * h_shift:, w_shift:-w_shift], size=(h * up_h, w * up_w), mode='nearest')
+     feat_sum += bottom * prob.narrow(1, 7, 1)
+ 
+     bottom_right = F.interpolate(feat_pd[:, :, 2 * h_shift:, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
+     feat_sum += bottom_right * prob.narrow(1, 8, 1)
+ 
+     return feat_sum
+ 
+ 
+ def suck_and_spread(self, base_maps, seg_layers):
+     N,S,H,W = seg_layers.shape
+     base_maps = base_maps.unsqueeze(1)
+     seg_layers = seg_layers.unsqueeze(2)
+     ## (N,S,C,1,1) = (N,1,C,H,W) * (N,S,1,H,W)
+     mean_val_layers = (base_maps * seg_layers).sum(dim=(3,4), keepdim=True) / (1e-5 + seg_layers.sum(dim=(3,4), keepdim=True))
+     ## normalized to be sum one
+     weight_layers = seg_layers / (1e-5 + torch.sum(seg_layers, dim=1, keepdim=True))
+     ## (N,S,C,H,W) = (N,S,C,1,1) * (N,S,1,H,W)
+     recon_maps = mean_val_layers * weight_layers
+     return recon_maps.sum(dim=1)
+ 
+ 
+ #! copy from Richard Zhang [SIGGRAPH2017]
+ # RGB grid points map to Lab range: L[0,100], a[-86.183,98.233], b[-107.857,94.478]
+ #------------------------------------------------------------------------------
+ def rgb2xyz(rgb):  # rgb from [0,1]
+     # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
+     #                          [0.212671, 0.715160, 0.072169],
+     #                          [0.019334, 0.119193, 0.950227]])
+     mask = (rgb > .04045).type(torch.FloatTensor)
+     if(rgb.is_cuda):
+         mask = mask.cuda()
+     rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask)
+     x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:]
+     y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:]
+     z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:]
+     out = torch.cat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1)
+     return out
+ 
+ def xyz2rgb(xyz):
+     # array([[ 3.24048134, -1.53715152, -0.49853633],
+     #        [-0.96925495,  1.87599   ,  0.04155593],
+     #        [ 0.05564664, -0.20404134,  1.05731107]])
+     r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:]
+     g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:]
+     b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:]
+     rgb = torch.cat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1)
+     #! sometimes reaches a small negative number, which causes NaNs
+     rgb = torch.max(rgb,torch.zeros_like(rgb))
+     mask = (rgb > .0031308).type(torch.FloatTensor)
+     if(rgb.is_cuda):
+         mask = mask.cuda()
+     rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask)
+     return rgb
+ 
+ def xyz2lab(xyz):
+     # 0.95047, 1., 1.08883 # white
+     sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
+     if(xyz.is_cuda):
+         sc = sc.cuda()
+     xyz_scale = xyz/sc
+     mask = (xyz_scale > .008856).type(torch.FloatTensor)
+     if(xyz_scale.is_cuda):
+         mask = mask.cuda()
+     xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask)
+     L = 116.*xyz_int[:,1,:,:]-16.
+     a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:])
+     b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:])
+     out = torch.cat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1)
+     return out
+ 
+ def lab2xyz(lab):
+     y_int = (lab[:,0,:,:]+16.)/116.
+     x_int = (lab[:,1,:,:]/500.) + y_int
+     z_int = y_int - (lab[:,2,:,:]/200.)
+     if(z_int.is_cuda):
+         z_int = torch.max(torch.Tensor((0,)).cuda(), z_int)
+     else:
+         z_int = torch.max(torch.Tensor((0,)), z_int)
+     out = torch.cat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1)
+     mask = (out > .2068966).type(torch.FloatTensor)
+     if(out.is_cuda):
+         mask = mask.cuda()
+     out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask)
+     sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
+     sc = sc.to(out.device)
+     out = out*sc
+     return out
+ 
+ def rgb2lab(rgb, l_mean=50, l_norm=50, ab_norm=110):
+     #! input rgb: [0,1]
+     #! output lab: [-1,1]
+     lab = xyz2lab(rgb2xyz(rgb))
+     l_rs = (lab[:,[0],:,:]-l_mean) / l_norm
+     ab_rs = lab[:,1:,:,:] / ab_norm
+     out = torch.cat((l_rs,ab_rs),dim=1)
+     return out
+ 
+ def lab2rgb(lab_rs, l_mean=50, l_norm=50, ab_norm=110):
+     #! input lab: [-1,1]
+     #! output rgb: [0,1]
+     l_ = lab_rs[:,[0],:,:] * l_norm + l_mean
+     ab = lab_rs[:,1:,:,:] * ab_norm
+     lab = torch.cat((l_,ab), dim=1)
+     out = xyz2rgb(lab2xyz(lab))
+     return out
+ 
+ 
+ if __name__ == '__main__':
+     minL, minA, minB = 999., 999., 999.
+     maxL, maxA, maxB = 0., 0., 0.
+     for r in range(256):
+         print('h', r)
+         for g in range(256):
+             for b in range(256):
+                 rgb = np.array([r,g,b], np.float32).reshape(1,1,-1) / 255.0
+                 #lab_img = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
+                 rgb = torch.from_numpy(rgb.transpose((2, 0, 1)))
+                 rgb = rgb.reshape(1,3,1,1)
+                 lab = rgb2lab(rgb)
+                 lab[:,[0],:,:] = lab[:,[0],:,:] * 50 + 50
+                 lab[:,1:,:,:] = lab[:,1:,:,:] * 110
+                 lab = lab.squeeze()
+                 lab_float = lab.numpy()
+                 #print('zhang vs. cv2:', lab_float, lab_img.squeeze())
+                 minL = min(lab_float[0], minL)
+                 minA = min(lab_float[1], minA)
+                 minB = min(lab_float[2], minB)
+                 maxL = max(lab_float[0], maxL)
+                 maxA = max(lab_float[1], maxA)
+                 maxB = max(lab_float[2], maxB)
+     print('L:', minL, maxL)
+     print('A:', minA, maxA)
+     print('B:', minB, maxB)
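`poolfeat` and `upfeat` above form a superpixel pooling/unpooling pair over a 9-channel association map (each pixel's soft assignment to the 3x3 neighborhood of candidate superpixels); `SPixelLoss` in models/loss.py uses exactly this round trip. A shape-level sketch with random inputs, assuming it is run from the repo root so `models` and `utils` are importable:

```python
# Shape-level sketch of the poolfeat/upfeat round trip used by SPixelLoss
# (random inputs; sp=16 matches the sp_size used in inference.py).
import torch
import torch.nn.functional as F
from models import basic

N, C, H, W, sp = 1, 5, 64, 64, 16
feat = torch.rand(N, C, H, W)
# 9-way soft assignment of every pixel to its neighboring superpixels.
prob = F.softmax(torch.rand(N, 9, H, W), dim=1)

pooled = basic.poolfeat(feat, prob, sp, sp)   # (1, 5, 4, 4): per-superpixel means
recon = basic.upfeat(pooled, prob, sp, sp)    # (1, 5, 64, 64): spread back to pixels
print(pooled.shape, recon.shape)
```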
models/clusterkit.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from functools import partial
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+ import math, random
10
+ #from sklearn.cluster import KMeans, kmeans_plusplus, MeanShift, estimate_bandwidth
11
+
12
+
13
+ def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
14
+ N,C,H,W = data_vecs.shape
15
+ assert N == 1, 'only support singe image tensor'
16
+ ## (1,C,H,W) -> (HW,C)
17
+ data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
18
+ ## convert tensor to array
19
+ data_vecs_np = data_vecs.squeeze().detach().to("cpu").numpy()
20
+ km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
21
+ pred = km.fit_predict(data_vecs_np)
22
+ cluster_ids_x = torch.from_numpy(km.labels_).to(data_vecs.device)
23
+ id_maps = cluster_ids_x.reshape(1,1,H,W).long()
24
+ if need_layer_masks:
25
+ one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
26
+ cluster_mask = one_hot_labels.permute(0,3,1,2)
27
+ return cluster_mask
28
+ return id_maps
29
+
30
+
31
+ def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
32
+ N,C,H,W = data_vecs.shape
33
+ assert N == 1, 'only support singe image tensor'
34
+
35
+ ## (1,C,H,W) -> (HW,C)
36
+ data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
37
+ ## cosine | euclidean
38
+ #cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric, device=data_vecs.device)
39
+ cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,\
40
+ tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
41
+ id_maps = cluster_ids_x.reshape(1,1,H,W)
42
+ if need_layer_masks:
43
+ one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
44
+ cluster_mask = one_hot_labels.permute(0,3,1,2)
45
+ return cluster_mask
46
+ return id_maps
47
+
48
+
49
+ def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
50
+ N,C,H,W = data_vecs.shape
51
+ sample_list = []
52
+ for idx in range(N):
53
+ if use_sklearn_kmeans:
54
+ cluster_mask = tensor_kmeans_sklearn(data_vecs[idx:idx+1,:,:,:], n_clusters, metric, True)
55
+ else:
56
+ cluster_mask = tensor_kmeans_pytorch(data_vecs[idx:idx+1,:,:,:], n_clusters, metric, True)
57
+ sample_list.append(cluster_mask)
58
+ return torch.cat(sample_list, dim=0)
59
+
60
+
61
+ def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
62
+ N,C,H,W = data_vecs.shape
63
+ data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
64
+ cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,\
65
+ tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
66
+ return cluster_centers
67
+
68
+
69
+ def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
70
+ N,C,H,W = data_tensor.shape
71
+ centroid_list = []
72
+ for idx in range(N):
73
+ cluster_centers = get_centroid_candidates(data_tensor[idx:idx+1,:,:,:], n_clusters, metric)
74
+ centroid_list.append(cluster_centers)
75
+
76
+ batch_centroids = torch.stack(centroid_list, dim=0)
77
+ data_vecs = data_tensor.flatten(2)
78
+ ## distance matrix: (N,K,HW) = (N,K,C) x (N,C,HW)
79
+ AtB = torch.matmul(batch_centroids, data_vecs)
80
+ AtA = torch.matmul(batch_centroids, batch_centroids.permute(0,2,1))
81
+ BtB = torch.matmul(data_vecs.permute(0,2,1), data_vecs)
82
+ diag_A = torch.diagonal(AtA, dim1=-2, dim2=-1)
83
+ diag_B = torch.diagonal(BtB, dim1=-2, dim2=-1)
84
+ A2 = diag_A.unsqueeze(2).repeat(1,1,H*W)
85
+ B2 = diag_B.unsqueeze(1).repeat(1,n_clusters,1)
86
+ distance_map = A2 - 2*AtB + B2
87
+ values, indices = distance_map.topk(topk, dim=2, largest=False, sorted=True)
88
+ cluster_mask = torch.where(distance_map <= values[:,:,topk-1:], torch.ones_like(distance_map), torch.zeros_like(distance_map))
89
+ cluster_mask = cluster_mask.view(N,n_clusters,H,W)
90
+ return cluster_mask
91
+
92
+
93
+ ##---------------------------------------------------------------------------------
94
+ '''
95
+ resource from github: https://github.com/subhadarship/kmeans_pytorch
96
+ '''
97
+ ##---------------------------------------------------------------------------------
98
+
99
+ def initialize(X, num_clusters):
100
+ """
101
+ initialize cluster centers
102
+ :param X: (torch.tensor) matrix
103
+ :param num_clusters: (int) number of clusters
104
+ :return: (np.array) initial state
105
+ """
106
+ num_samples = len(X)
107
+ indices = np.random.choice(num_samples, num_clusters, replace=False)
108
+ initial_state = X[indices]
109
+ return initial_state
110
+
111
+
112
+ def kmeans(
113
+ X,
114
+ num_clusters,
115
+ distance='euclidean',
116
+ cluster_centers=[],
117
+ tol=1e-4,
118
+ tqdm_flag=True,
119
+ iter_limit=0,
120
+ device=torch.device('cpu'),
121
+ gamma_for_soft_dtw=0.001
122
+ ):
123
+ """
124
+ perform kmeans
125
+ :param X: (torch.tensor) matrix
126
+ :param num_clusters: (int) number of clusters
127
+ :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
128
+ :param tol: (float) threshold [default: 0.0001]
129
+ :param device: (torch.device) device [default: cpu]
130
+ :param tqdm_flag: Allows to turn logs on and off
131
+ :param iter_limit: hard limit for max number of iterations
132
+ :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
133
+ :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
134
+ """
135
+ if tqdm_flag:
136
+ print(f'running k-means on {device}..')
137
+
138
+ if distance == 'euclidean':
139
+ pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
140
+ elif distance == 'cosine':
141
+ pairwise_distance_function = partial(pairwise_cosine, device=device)
142
+ else:
143
+ raise NotImplementedError
144
+
145
+ # convert to float
146
+ X = X.float()
147
+
148
+ # transfer to device
149
+ X = X.to(device)
150
+
151
+ # initialize
152
+ if type(cluster_centers) == list: # ToDo: make this less annoyingly weird
153
+ initial_state = initialize(X, num_clusters)
154
+ else:
155
+ if tqdm_flag:
156
+ print('resuming')
157
+ # find data point closest to the initial cluster center
158
+ initial_state = cluster_centers
159
+ dis = pairwise_distance_function(X, initial_state)
160
+ choice_points = torch.argmin(dis, dim=0)
161
+ initial_state = X[choice_points]
162
+ initial_state = initial_state.to(device)
163
+
164
+ iteration = 0
165
+ if tqdm_flag:
166
+ tqdm_meter = tqdm(desc='[running kmeans]')
167
+ while True:
168
+
169
+ dis = pairwise_distance_function(X, initial_state)
170
+
171
+         choice_cluster = torch.argmin(dis, dim=1)
+
+         initial_state_pre = initial_state.clone()
+
+         for index in range(num_clusters):
+             selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
+
+             selected = torch.index_select(X, 0, selected)
+
+             # https://github.com/subhadarship/kmeans_pytorch/issues/16
+             if selected.shape[0] == 0:
+                 selected = X[torch.randint(len(X), (1,))]
+
+             initial_state[index] = selected.mean(dim=0)
+
+         center_shift = torch.sum(
+             torch.sqrt(
+                 torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
+             ))
+
+         # increment iteration
+         iteration = iteration + 1
+
+         # update tqdm meter
+         if tqdm_flag:
+             tqdm_meter.set_postfix(
+                 iteration=f'{iteration}',
+                 center_shift=f'{center_shift ** 2:0.6f}',
+                 tol=f'{tol:0.6f}'
+             )
+             tqdm_meter.update()
+         if center_shift ** 2 < tol:
+             break
+         if iter_limit != 0 and iteration >= iter_limit:
+             break
+
+     return choice_cluster.to(device), initial_state.to(device)
+
+
+ def kmeans_predict(
+         X,
+         cluster_centers,
+         distance='euclidean',
+         device=torch.device('cpu'),
+         gamma_for_soft_dtw=0.001,
+         tqdm_flag=True
+ ):
+     """
+     predict using cluster centers
+     :param X: (torch.tensor) matrix
+     :param cluster_centers: (torch.tensor) cluster centers
+     :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
+     :param device: (torch.device) device [default: 'cpu']
+     :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
+     :return: (torch.tensor) cluster ids
+     """
+     if tqdm_flag:
+         print(f'predicting on {device}..')
+
+     if distance == 'euclidean':
+         pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
+     elif distance == 'cosine':
+         pairwise_distance_function = partial(pairwise_cosine, device=device)
+     elif distance == 'soft_dtw':
+         sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
+         pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
+     else:
+         raise NotImplementedError
+
+     # convert to float
+     X = X.float()
+
+     # transfer to device
+     X = X.to(device)
+
+     dis = pairwise_distance_function(X, cluster_centers)
+     choice_cluster = torch.argmin(dis, dim=1)
+
+     return choice_cluster.cpu()
+
+
+ def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
+     if tqdm_flag:
+         print(f'device is :{device}')
+
+     # transfer to device
+     data1, data2 = data1.to(device), data2.to(device)
+
+     # N1*1*M
+     A = data1.unsqueeze(dim=1)
+
+     # 1*N2*M
+     B = data2.unsqueeze(dim=0)
+
+     dis = (A - B) ** 2.0
+     # sum over the feature axis -> N1*N2 matrix of pairwise squared distances
+     dis = dis.sum(dim=-1).squeeze()
+     return dis
+
+
+ def pairwise_cosine(data1, data2, device=torch.device('cpu')):
+     # transfer to device
+     data1, data2 = data1.to(device), data2.to(device)
+
+     # N1*1*M
+     A = data1.unsqueeze(dim=1)
+
+     # 1*N2*M
+     B = data2.unsqueeze(dim=0)
+
+     # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
+     A_normalized = A / A.norm(dim=-1, keepdim=True)
+     B_normalized = B / B.norm(dim=-1, keepdim=True)
+
+     cosine = A_normalized * B_normalized
+
+     # sum over the feature axis -> N1*N2 matrix of pairwise cosine distances
+     cosine_dis = 1 - cosine.sum(dim=-1).squeeze()
+     return cosine_dis
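
Note: a minimal usage sketch of the clustering helpers above (illustration only, not part of the committed diff). It assumes the repo root is on PYTHONPATH, and that the `kmeans` fitting routine whose loop ends above follows the upstream kmeans_pytorch signature `kmeans(X, num_clusters, distance, device)` — an assumption, since the function head sits outside this excerpt:

    import torch
    from models import clusterkit

    feats = torch.randn(500, 64)               # 500 feature vectors of dim 64
    # fit: per-point cluster ids plus the fitted centers (assumed upstream API)
    ids, centers = clusterkit.kmeans(X=feats, num_clusters=8,
                                     distance='euclidean', device=torch.device('cpu'))
    # assign unseen points to the fitted centers
    new_ids = clusterkit.kmeans_predict(torch.randn(10, 64), centers,
                                        distance='euclidean', device=torch.device('cpu'))
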
models/loss.py ADDED
@@ -0,0 +1,222 @@
+ from __future__ import division
+ import os, glob, shutil, math, random, json
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+ from models import basic   # import as a package module, matching models/model.py
+ from utils import util
+
+ eps = 0.0000001
+
+ class SPixelLoss:
+     def __init__(self, psize=8, mpdist=False, gpu_no=0):
+         self.mpdist = mpdist
+         self.gpu_no = gpu_no
+         self.sp_size = psize
+
+     def __call__(self, data, epoch_no):
+         kernel_size = self.sp_size
+         #pos_weight = 0.003
+         prob = data['pred_prob']
+         labxy_feat = data['target_feat']
+         N,C,H,W = labxy_feat.shape
+         pooled_labxy = basic.poolfeat(labxy_feat, prob, kernel_size, kernel_size)
+         reconstr_feat = basic.upfeat(pooled_labxy, prob, kernel_size, kernel_size)
+         loss_map = reconstr_feat[:,:,:,:] - labxy_feat[:,:,:,:]
+         featLoss_idx = torch.norm(loss_map[:,:-2,:,:], p=2, dim=1).mean()
+         posLoss_idx = torch.norm(loss_map[:,-2:,:,:], p=2, dim=1).mean() / kernel_size
+         totalLoss_idx = 10*featLoss_idx + 0.003*posLoss_idx
+         return {'totalLoss':totalLoss_idx, 'featLoss':featLoss_idx, 'posLoss':posLoss_idx}
+
+
+ class AnchorColorProbLoss:
+     def __init__(self, hint2regress=False, enhanced=False, with_grad=False, mpdist=False, gpu_no=0):
+         self.mpdist = mpdist
+         self.gpu_no = gpu_no
+         self.hint2regress = hint2regress
+         self.enhanced = enhanced
+         self.with_grad = with_grad
+         self.rebalance_gradient = basic.RebalanceLoss.apply
+         self.entropy_loss = nn.CrossEntropyLoss(ignore_index=-1)
+         if self.enhanced:
+             self.VGGLoss = VGG19Loss(gpu_no=gpu_no, is_ddp=mpdist)
+
+     def _perceptual_loss(self, input_grays, input_colors, pred_colors):
+         input_RGBs = basic.lab2rgb(torch.cat([input_grays,input_colors], dim=1))
+         pred_RGBs = basic.lab2rgb(torch.cat([input_grays,pred_colors], dim=1))
+         ## the output of "lab2rgb" just matches the input of "VGGLoss": [0,1]
+         return self.VGGLoss(input_RGBs, pred_RGBs)
+
+     def _laplace_gradient(self, pred_AB, target_AB):
+         N,C,H,W = pred_AB.shape
+         kernel = torch.tensor([[1, 1, 1], [1, -8, 1], [1, 1, 1]], device=pred_AB.get_device()).float()
+         kernel = kernel.view(1, 1, *kernel.size()).repeat(C,1,1,1)
+         grad_pred = F.conv2d(pred_AB, kernel, groups=C)
+         grad_trg = F.conv2d(target_AB, kernel, groups=C)
+         return l1_loss(grad_trg, grad_pred)
+
+     def __call__(self, data, epoch_no):
+         N,C,H,W = data['target_label'].shape
+         pal_probs = self.rebalance_gradient(data['pal_prob'], data['class_weight'])
+         #ref_probs = data['ref_prob']
+         pal_probs = pal_probs.permute(0,2,3,1).contiguous().view(N*H*W, -1)
+         gt_labels = data['target_label'].permute(0,2,3,1).contiguous().view(N*H*W, -1)
+         '''
+         ignored_mask = data['empty_entries'].permute(0,2,3,1).contiguous().view(N*H*W, -1)
+         gt_labels[ignored_mask] = -1
+         gt_labels = gt_probs.squeeze()
+         '''
+         palLoss_idx = self.entropy_loss(pal_probs, gt_labels.squeeze(dim=1))
+         if self.hint2regress:
+             ref_probs = data['ref_prob']
+             refLoss_idx = 50 * l2_loss(data['spix_color'], ref_probs)
+         else:
+             ref_probs = self.rebalance_gradient(data['ref_prob'], data['class_weight'])
+             ref_probs = ref_probs.permute(0,2,3,1).contiguous().view(N*H*W, -1)
+             refLoss_idx = self.entropy_loss(ref_probs, gt_labels.squeeze(dim=1))
+         reconLoss_idx = torch.zeros_like(palLoss_idx)
+         if self.enhanced:
+             scalar = 1.0 if self.hint2regress else 5.0
+             reconLoss_idx = scalar * self._perceptual_loss(data['input_gray'], data['pred_color'], data['input_color'])
+             if self.with_grad:
+                 gradient_loss = self._laplace_gradient(data['pred_color'], data['input_color'])
+                 reconLoss_idx += gradient_loss
+         totalLoss_idx = palLoss_idx + refLoss_idx + reconLoss_idx
+         #print("loss terms:", palLoss_idx.item(), refLoss_idx.item(), reconLoss_idx.item())
+         return {'totalLoss':totalLoss_idx, 'palLoss':palLoss_idx, 'refLoss':refLoss_idx, 'recLoss':reconLoss_idx}
+
+
+ def compute_affinity_pos_loss(prob_in, labxy_feat, pos_weight=0.003, kernel_size=16):
+     S = kernel_size
+     m = pos_weight
+     prob = prob_in.clone()
+     N,C,H,W = labxy_feat.shape
+     pooled_labxy = basic.poolfeat(labxy_feat, prob, kernel_size, kernel_size)
+     reconstr_feat = basic.upfeat(pooled_labxy, prob, kernel_size, kernel_size)
+     loss_map = reconstr_feat[:,:,:,:] - labxy_feat[:,:,:,:]
+     loss_feat = torch.norm(loss_map[:,:-2,:,:], p=2, dim=1).mean()
+     loss_pos = torch.norm(loss_map[:,-2:,:,:], p=2, dim=1).mean() * m / S
+     loss_affinity = loss_feat + loss_pos
+     return loss_affinity
+
+
+ def l2_loss(y_input, y_target, weight_map=None):
+     if weight_map is None:
+         return F.mse_loss(y_input, y_target)
+     else:
+         diff_map = torch.mean(torch.abs(y_input-y_target), dim=1, keepdim=True)
+         batch_dev = torch.sum(diff_map*diff_map*weight_map, dim=(1,2,3)) / (eps+torch.sum(weight_map, dim=(1,2,3)))
+         return batch_dev.mean()
+
+
+ def l1_loss(y_input, y_target, weight_map=None):
+     if weight_map is None:
+         return F.l1_loss(y_input, y_target)
+     else:
+         diff_map = torch.mean(torch.abs(y_input-y_target), dim=1, keepdim=True)
+         batch_dev = torch.sum(diff_map*weight_map, dim=(1,2,3)) / (eps+torch.sum(weight_map, dim=(1,2,3)))
+         return batch_dev.mean()
+
+
+ def masked_l1_loss(y_input, y_target, outlier_mask):
+     one = torch.tensor([1.0]).cuda(y_input.get_device())
+     weight_map = torch.where(outlier_mask, one * 0.0, one * 1.0)
+     return l1_loss(y_input, y_target, weight_map)
+
+
+ def huber_loss(y_input, y_target, delta=0.01):
+     mask = torch.zeros_like(y_input)
+     mann = torch.abs(y_input - y_target)
+     eucl = 0.5 * (mann**2)
+     mask[...] = mann < delta
+     loss = eucl * mask / delta + (mann - 0.5 * delta) * (1 - mask)
+     return torch.mean(loss)
+
+
+ ## Perceptual loss that uses a pretrained VGG network
+ class VGG19Loss(nn.Module):
+     def __init__(self, feat_type='liu', gpu_no=0, is_ddp=False, requires_grad=False):
+         super(VGG19Loss, self).__init__()
+         os.environ['TORCH_HOME'] = '/apdcephfs/share_1290939/richardxia/Saved/Checkpoints/VGG19'
+         ## data requirement: (N,C,H,W) in RGB format, [0,1] range, and resolution >= 224x224
+         self.mean = [0.485, 0.456, 0.406]
+         self.std = [0.229, 0.224, 0.225]
+         self.feat_type = feat_type
+
+         vgg_model = torchvision.models.vgg19(pretrained=True)
+         ## AssertionError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient
+         '''
+         if is_ddp:
+             vgg_model = vgg_model.cuda(gpu_no)
+             vgg_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vgg_model)
+             vgg_model = torch.nn.parallel.DistributedDataParallel(vgg_model, device_ids=[gpu_no], find_unused_parameters=True)
+         else:
+             vgg_model = vgg_model.cuda(gpu_no)
+         '''
+         vgg_model = vgg_model.cuda(gpu_no)
+         if self.feat_type == 'liu':
+             ## conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
+             self.slice1 = nn.Sequential(*list(vgg_model.features)[:2]).eval()
+             self.slice2 = nn.Sequential(*list(vgg_model.features)[2:7]).eval()
+             self.slice3 = nn.Sequential(*list(vgg_model.features)[7:12]).eval()
+             self.slice4 = nn.Sequential(*list(vgg_model.features)[12:21]).eval()
+             self.slice5 = nn.Sequential(*list(vgg_model.features)[21:30]).eval()
+             self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
+         elif self.feat_type == 'lei':
+             ## conv1_2, conv2_2, conv3_2, conv4_2, conv5_2
+             self.slice1 = nn.Sequential(*list(vgg_model.features)[:4]).eval()
+             self.slice2 = nn.Sequential(*list(vgg_model.features)[4:9]).eval()
+             self.slice3 = nn.Sequential(*list(vgg_model.features)[9:14]).eval()
+             self.slice4 = nn.Sequential(*list(vgg_model.features)[14:23]).eval()
+             self.slice5 = nn.Sequential(*list(vgg_model.features)[23:32]).eval()
+             self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10.0/1.5]
+         else:
+             ## maxpool after conv4_4
+             self.featureExactor = nn.Sequential(*list(vgg_model.features)[:28]).eval()
+         '''
+         for x in range(2):
+             self.slice1.add_module(str(x), pretrained_features[x])
+         for x in range(2, 7):
+             self.slice2.add_module(str(x), pretrained_features[x])
+         for x in range(7, 12):
+             self.slice3.add_module(str(x), pretrained_features[x])
+         for x in range(12, 21):
+             self.slice4.add_module(str(x), pretrained_features[x])
+         for x in range(21, 30):
+             self.slice5.add_module(str(x), pretrained_features[x])
+         '''
+         self.criterion = nn.L1Loss()
+
+         ## fixed parameters
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+             self.eval()
+         print('[*] VGG19Loss init!')
+
+     def normalize(self, tensor):
+         tensor = tensor.clone()
+         mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
+         std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
+         tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
+         return tensor
+
+     def forward(self, x, y):
+         norm_x, norm_y = self.normalize(x), self.normalize(y)
+         ## feature extract
+         if self.feat_type == 'liu' or self.feat_type == 'lei':
+             x_relu1, y_relu1 = self.slice1(norm_x), self.slice1(norm_y)
+             x_relu2, y_relu2 = self.slice2(x_relu1), self.slice2(y_relu1)
+             x_relu3, y_relu3 = self.slice3(x_relu2), self.slice3(y_relu2)
+             x_relu4, y_relu4 = self.slice4(x_relu3), self.slice4(y_relu3)
+             x_relu5, y_relu5 = self.slice5(x_relu4), self.slice5(y_relu4)
+             x_vgg = [x_relu1, x_relu2, x_relu3, x_relu4, x_relu5]
+             y_vgg = [y_relu1, y_relu2, y_relu3, y_relu4, y_relu5]
+             loss = 0
+             for i in range(len(x_vgg)):
+                 loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
+         else:
+             x_vgg, y_vgg = self.featureExactor(norm_x), self.featureExactor(norm_y)
+             loss = self.criterion(x_vgg, y_vgg.detach())
+         return loss
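
Note: a minimal sketch of the standalone VGG19Loss (illustration only). It needs a CUDA device and downloads the torchvision VGG19 weights on first use — mind the hard-coded TORCH_HOME above, which you may need to adjust. Inputs must be RGB in [0,1], per the comment in __init__:

    import torch
    from models.loss import VGG19Loss

    vgg_loss = VGG19Loss(feat_type='liu', gpu_no=0)
    pred = torch.rand(1, 3, 256, 256).cuda()     # RGB in [0,1]
    target = torch.rand(1, 3, 256, 256).cuda()
    loss = vgg_loss(pred, target)                # scalar perceptual distance
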
models/model.py ADDED
@@ -0,0 +1,196 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models.network import HourGlass2, SpixelNet, ColorProbNet
+ from models.transformer2d import EncoderLayer, DecoderLayer, TransformerEncoder, TransformerDecoder
+ from models.position_encoding import build_position_encoding
+ from models import basic, clusterkit, anchor_gen
+ from collections import OrderedDict
+ from utils import util, cielab
+
+
+ class SpixelSeg(nn.Module):
+     def __init__(self, inChannel=1, outChannel=9, batchNorm=True):
+         super(SpixelSeg, self).__init__()
+         self.net = SpixelNet(inChannel=inChannel, outChannel=outChannel, batchNorm=batchNorm)
+
+     def get_trainable_params(self, lr=1.0):
+         #print('=> [optimizer] finetune backbone with smaller lr')
+         params = []
+         for name, param in self.named_parameters():
+             if 'xxx' in name:
+                 params.append({'params': param, 'lr': lr})
+             else:
+                 params.append({'params': param})
+         return params
+
+     def forward(self, input_grays):
+         pred_probs = self.net(input_grays)
+         return pred_probs
+
+
+ class AnchorColorProb(nn.Module):
+     def __init__(self, inChannel=1, outChannel=313, sp_size=16, d_model=64, use_dense_pos=True, spix_pos=False, learning_pos=False, \
+                  random_hint=False, hint2regress=False, enhanced=False, use_mask=False, rank=0, colorLabeler=None):
+         super(AnchorColorProb, self).__init__()
+         self.sp_size = sp_size
+         self.spix_pos = spix_pos
+         self.use_token_mask = use_mask
+         self.hint2regress = hint2regress
+         self.segnet = SpixelSeg(inChannel=1, outChannel=9, batchNorm=True)
+         self.repnet = ColorProbNet(inChannel=inChannel, outChannel=64)
+         self.enhanced = enhanced
+         if self.enhanced:
+             self.enhanceNet = HourGlass2(inChannel=64+1, outChannel=2, resNum=3, normLayer=nn.BatchNorm2d)
+
+         ## transformer architecture
+         self.n_vocab = 313
+         d_model, dim_feedforward, nhead = d_model, 4*d_model, 8
+         dropout, activation = 0.1, "relu"
+         n_enc_layers, n_dec_layers = 6, 6
+         enc_layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, use_dense_pos)
+         self.wildpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
+         self.hintpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
+         if self.spix_pos:
+             n_pos_x, n_pos_y = 256, 256
+         else:
+             n_pos_x, n_pos_y = 256//sp_size, 16//sp_size
+         self.pos_enc = build_position_encoding(d_model//2, n_pos_x, n_pos_y, is_learned=False)
+
+         self.mid_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)
+         if self.hint2regress:
+             self.trg_word_emb = nn.Linear(d_model+2+1, d_model, bias=False)
+             self.trg_word_prj = nn.Linear(d_model, 2, bias=False)
+         else:
+             self.trg_word_emb = nn.Linear(d_model+self.n_vocab+1, d_model, bias=False)
+             self.trg_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)
+
+         self.colorLabeler = colorLabeler
+         anchor_mode = 'random' if random_hint else 'clustering'
+         self.anchorGen = anchor_gen.AnchorAnalysis(mode=anchor_mode, colorLabeler=self.colorLabeler)
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def load_and_froze_weight(self, checkpt_path):
+         data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
+         '''
+         for param_tensor in data_dict['state_dict']:
+             print(param_tensor,'\t',data_dict['state_dict'][param_tensor].size())
+         '''
+         self.segnet.load_state_dict(data_dict['state_dict'])
+         for name, param in self.segnet.named_parameters():
+             param.requires_grad = False
+         self.segnet.eval()
+
+     def set_train(self):
+         ## running mode only affects certain modules, e.g. Dropout, BN, etc.
+         self.repnet.train()
+         self.wildpath.train()
+         self.hintpath.train()
+         if self.enhanced:
+             self.enhanceNet.train()
+
+     def get_entry_mask(self, mask_tensor):
+         if mask_tensor is None:
+             return None
+         ## flatten (N,1,H,W) to (N,HW)
+         return mask_tensor.flatten(1)
+
+     def forward(self, input_grays, input_colors, n_anchors=8, sampled_T=0):
+         '''
+         Notice: this function was customized for inference only
+         '''
+         affinity_map = self.segnet(input_grays)
+         pred_feats = self.repnet(input_grays)
+         if self.spix_pos:
+             full_pos_feats = self.pos_enc(pred_feats)
+             proxy_feats = torch.cat([pred_feats, input_colors, full_pos_feats], dim=1)
+             pooled_proxy_feats, conf_sum = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
+             feat_tokens = pooled_proxy_feats[:,:64,:,:]
+             spix_colors = pooled_proxy_feats[:,64:66,:,:]
+             pos_feats = pooled_proxy_feats[:,66:,:,:]
+         else:
+             proxy_feats = torch.cat([pred_feats, input_colors], dim=1)
+             pooled_proxy_feats, conf_sum = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
+             feat_tokens = pooled_proxy_feats[:,:64,:,:]
+             spix_colors = pooled_proxy_feats[:,64:,:,:]
+             pos_feats = self.pos_enc(feat_tokens)
+
+         token_labels = torch.max(self.colorLabeler.encode_ab2ind(spix_colors), dim=1, keepdim=True)[1]
+         spixel_sizes = basic.get_spixel_size(affinity_map, self.sp_size, self.sp_size)
+         all_one_map = torch.ones(spixel_sizes.shape, device=input_grays.device)
+         empty_entries = torch.where(spixel_sizes < 25/(self.sp_size**2), all_one_map, 1-all_one_map)
+         src_pad_mask = self.get_entry_mask(empty_entries) if self.use_token_mask else None
+         trg_pad_mask = src_pad_mask
+
+         ## parallel prob
+         N,C,H,W = feat_tokens.shape
+         ## (N,C,H,W) -> (HW,N,C)
+         src_pos_seq = pos_feats.flatten(2).permute(2, 0, 1)
+         src_seq = feat_tokens.flatten(2).permute(2, 0, 1)
+         ## color prob branch
+         enc_out, _ = self.wildpath(src_seq, src_pos_seq, src_pad_mask)
+         pal_logit = self.mid_word_prj(enc_out)
+         pal_logit = pal_logit.permute(1, 2, 0).view(N,self.n_vocab,H,W)
+
+         ## seed prob branch
+         ## mask(N,1,H,W): sample anchors at clustering layers
+         color_feat = enc_out.permute(1, 2, 0).view(N,C,H,W)
+         hint_mask, cluster_mask = self.anchorGen(color_feat, n_anchors, spixel_sizes, use_sklearn_kmeans=False)
+         pred_prob = torch.softmax(pal_logit, dim=1)
+         color_feat2 = src_seq.permute(1, 2, 0).view(N,C,H,W)
+         #pred_prob, adj_matrix = self.anchorGen._detect_correlation(color_feat, pred_prob, hint_mask, thres=0.1)
+         if sampled_T < 0:
+             ## GT anchor colors
+             sampled_spix_colors = spix_colors
+         elif sampled_T > 0:
+             top1_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=0)
+             top2_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=1)
+             top3_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=2)
+             ## duplicate meta tensors
+             sampled_spix_colors = torch.cat((top1_spix_colors,top2_spix_colors,top3_spix_colors), dim=0)
+             N = 3*N
+             input_grays = input_grays.expand(N,-1,-1,-1)
+             hint_mask = hint_mask.expand(N,-1,-1,-1)
+             affinity_map = affinity_map.expand(N,-1,-1,-1)
+             src_seq = src_seq.expand(-1, N,-1)
+             src_pos_seq = src_pos_seq.expand(-1, N,-1)
+         else:
+             sampled_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=sampled_T)
+         ## debug: controllable
+         if False:
+             hint_mask, sampled_spix_colors = basic.io_user_control(hint_mask, spix_colors, output=False)
+
+         sampled_token_labels = torch.max(self.colorLabeler.encode_ab2ind(sampled_spix_colors), dim=1, keepdim=True)[1]
+
+         ## hint based prediction
+         ## (N,C,H,W) -> (HW,N,C)
+         mask_seq = hint_mask.flatten(2).permute(2, 0, 1)
+         if self.hint2regress:
+             spix_colors_ = sampled_spix_colors
+             gt_seq = spix_colors_.flatten(2).permute(2, 0, 1)
+             hint_seq = self.trg_word_emb(torch.cat([src_seq, mask_seq * gt_seq, mask_seq], dim=2))
+             dec_out, _ = self.hintpath(hint_seq, src_pos_seq, src_pad_mask)
+         else:
+             token_labels_ = sampled_token_labels
+             label_map = F.one_hot(token_labels_, num_classes=313).squeeze(1).float()
+             label_seq = label_map.permute(0, 3, 1, 2).flatten(2).permute(2, 0, 1)
+             hint_seq = self.trg_word_emb(torch.cat([src_seq, mask_seq * label_seq, mask_seq], dim=2))
+             dec_out, _ = self.hintpath(hint_seq, src_pos_seq, src_pad_mask)
+         ref_logit = self.trg_word_prj(dec_out)
+         Ct = 2 if self.hint2regress else self.n_vocab
+         ref_logit = ref_logit.permute(1, 2, 0).view(N,Ct,H,W)
+
+         ## pixelwise enhancement
+         pred_colors = None
+         if self.enhanced:
+             proc_feats = dec_out.permute(1, 2, 0).view(N,64,H,W)
+             full_feats = basic.upfeat(proc_feats, affinity_map, self.sp_size, self.sp_size)
+             pred_colors = self.enhanceNet(torch.cat((input_grays,full_feats), dim=1))
+             pred_colors = torch.tanh(pred_colors)
+
+         return pal_logit, ref_logit, pred_colors, affinity_map, spix_colors, hint_mask
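
Note: a minimal inference sketch for AnchorColorProb (illustration only). It assumes a CUDA device and that `basic.ColorLabel` — constructed the same way in predict.py below — is the labeler providing the `encode_ab2ind`/`decode_ind2ab` interface this model expects:

    import torch
    from models import model, basic

    labeler = basic.ColorLabel(lambda_=0.5, device='cuda')
    net = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True,
                                colorLabeler=labeler).cuda().eval()
    grays = torch.zeros(1, 1, 256, 256).cuda()    # normalized L channel
    colors = torch.zeros(1, 2, 256, 256).cuda()   # normalized ab channels
    with torch.no_grad():
        pal_logit, ref_logit, pred_ab, affinity, spix_colors, hint_mask = \
            net(grays, colors, n_anchors=8, sampled_T=0)
    # pal_logit/ref_logit: (1, 313, 16, 16) token logits; pred_ab: (1, 2, 256, 256)
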
models/network.py ADDED
@@ -0,0 +1,352 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+ import torchvision
+ import torch.nn.utils.spectral_norm as spectral_norm
+ import math
+
+
+ class ConvBlock(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum, normLayer=None):
+         super(ConvBlock, self).__init__()
+         self.inConv = nn.Sequential(
+             nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+         layers = []
+         for _ in range(convNum - 1):
+             layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+         if not (normLayer is None):
+             layers.append(normLayer(outChannels))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.inConv(x)
+         x = self.conv(x)
+         return x
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels, normLayer=None):
+         super(ResidualBlock, self).__init__()
+         layers = []
+         layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
+         layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
+         if not (normLayer is None):
+             layers.append(normLayer(channels))
+         layers.append(nn.ReLU(inplace=True))
+         layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
+         if not (normLayer is None):
+             layers.append(normLayer(channels))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         residual = self.conv(x)
+         return F.relu(x + residual, inplace=True)
+
+
+ class ResidualBlockSN(nn.Module):
+     def __init__(self, channels, normLayer=None):
+         super(ResidualBlockSN, self).__init__()
+         layers = []
+         layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
+         layers.append(nn.LeakyReLU(0.2, True))
+         layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
+         if not (normLayer is None):
+             layers.append(normLayer(channels))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         residual = self.conv(x)
+         return F.leaky_relu(x + residual, 2e-1, inplace=True)
+
+
+ class DownsampleBlock(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
+         super(DownsampleBlock, self).__init__()
+         layers = []
+         layers.append(nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=2))
+         layers.append(nn.ReLU(inplace=True))
+         for _ in range(convNum - 1):
+             layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+         if not (normLayer is None):
+             layers.append(normLayer(outChannels))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.conv(x)
+
+
+ class UpsampleBlock(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
+         super(UpsampleBlock, self).__init__()
+         self.conv1 = nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=1)
+         self.combine = nn.Conv2d(2 * outChannels, outChannels, kernel_size=3, padding=1)
+         layers = []
+         for _ in range(convNum - 1):
+             layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+         if not (normLayer is None):
+             layers.append(normLayer(outChannels))
+         self.conv2 = nn.Sequential(*layers)
+
+     def forward(self, x, x0):
+         x = self.conv1(x)
+         x = F.interpolate(x, scale_factor=2, mode='nearest')
+         x = self.combine(torch.cat((x, x0), 1))
+         x = F.relu(x)
+         return self.conv2(x)
+
+
+ class UpsampleBlockSN(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
+         super(UpsampleBlockSN, self).__init__()
+         self.conv1 = spectral_norm(nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=1, padding=1))
+         self.shortcut = spectral_norm(nn.Conv2d(outChannels, outChannels, kernel_size=3, stride=1, padding=1))
+         layers = []
+         for _ in range(convNum - 1):
+             layers.append(spectral_norm(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1)))
+             layers.append(nn.LeakyReLU(0.2, True))
+         if not (normLayer is None):
+             layers.append(normLayer(outChannels))
+         self.conv2 = nn.Sequential(*layers)
+
+     def forward(self, x, x0):
+         x = self.conv1(x)
+         x = F.interpolate(x, scale_factor=2, mode='nearest')
+         x = x + self.shortcut(x0)
+         x = F.leaky_relu(x, 2e-1)
+         return self.conv2(x)
+
+
+ class HourGlass2(nn.Module):
+     def __init__(self, inChannel=3, outChannel=1, resNum=3, normLayer=None):
+         super(HourGlass2, self).__init__()
+         self.inConv = ConvBlock(inChannel, 64, convNum=2, normLayer=normLayer)
+         self.down1 = DownsampleBlock(64, 128, convNum=2, normLayer=normLayer)
+         self.down2 = DownsampleBlock(128, 256, convNum=2, normLayer=normLayer)
+         self.residual = nn.Sequential(*[ResidualBlock(256) for _ in range(resNum)])
+         self.up2 = UpsampleBlock(256, 128, convNum=3, normLayer=normLayer)
+         self.up1 = UpsampleBlock(128, 64, convNum=3, normLayer=normLayer)
+         self.outConv = nn.Conv2d(64, outChannel, kernel_size=3, padding=1)
+
+     def forward(self, x):
+         f1 = self.inConv(x)
+         f2 = self.down1(f1)
+         f3 = self.down2(f2)
+         r3 = self.residual(f3)
+         r2 = self.up2(r3, f2)
+         r1 = self.up1(r2, f1)
+         y = self.outConv(r1)
+         return y
+
+
+ class ColorProbNet(nn.Module):
+     def __init__(self, inChannel=1, outChannel=2, with_SA=False):
+         super(ColorProbNet, self).__init__()
+         BNFunc = nn.BatchNorm2d
+         # conv1: 256
+         conv1_2 = [spectral_norm(nn.Conv2d(inChannel, 64, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv1_2 += [spectral_norm(nn.Conv2d(64, 64, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv1_2 += [BNFunc(64, affine=True)]
+         # conv2: 128
+         conv2_3 = [spectral_norm(nn.Conv2d(64, 128, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv2_3 += [spectral_norm(nn.Conv2d(128, 128, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv2_3 += [spectral_norm(nn.Conv2d(128, 128, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv2_3 += [BNFunc(128, affine=True)]
+         # conv3: 64
+         conv3_3 = [spectral_norm(nn.Conv2d(128, 256, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv3_3 += [spectral_norm(nn.Conv2d(256, 256, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv3_3 += [spectral_norm(nn.Conv2d(256, 256, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv3_3 += [BNFunc(256, affine=True)]
+         # conv4: 32
+         conv4_3 = [spectral_norm(nn.Conv2d(256, 512, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv4_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv4_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv4_3 += [BNFunc(512, affine=True)]
+         # conv5: 32
+         conv5_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv5_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv5_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv5_3 += [BNFunc(512, affine=True)]
+         # conv6: 32
+         conv6_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv6_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv6_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv6_3 += [BNFunc(512, affine=True),]
+         if with_SA:
+             conv6_3 += [Self_Attn(512)]   # note: Self_Attn must be provided elsewhere when with_SA=True
+         # conv7: 32
+         conv7_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv7_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv7_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv7_3 += [BNFunc(512, affine=True)]
+         # conv8: 64
+         conv8up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(512, 256, 3, stride=1, padding=1),]
+         conv3short8 = [nn.Conv2d(256, 256, 3, stride=1, padding=1),]
+         conv8_3 = [nn.ReLU(True),]
+         conv8_3 += [nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU(True),]
+         conv8_3 += [nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU(True),]
+         conv8_3 += [BNFunc(256, affine=True),]
+         # conv9: 128
+         conv9up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(256, 128, 3, stride=1, padding=1),]
+         conv9_2 = [nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU(True),]
+         conv9_2 += [BNFunc(128, affine=True)]
+         # conv10: 64
+         conv10up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(128, 64, 3, stride=1, padding=1),]
+         conv10_2 = [nn.ReLU(True),]
+         conv10_2 += [nn.Conv2d(64, outChannel, 3, stride=1, padding=1), nn.ReLU(True),]
+
+         self.conv1_2 = nn.Sequential(*conv1_2)
+         self.conv2_3 = nn.Sequential(*conv2_3)
+         self.conv3_3 = nn.Sequential(*conv3_3)
+         self.conv4_3 = nn.Sequential(*conv4_3)
+         self.conv5_3 = nn.Sequential(*conv5_3)
+         self.conv6_3 = nn.Sequential(*conv6_3)
+         self.conv7_3 = nn.Sequential(*conv7_3)
+         self.conv8up = nn.Sequential(*conv8up)
+         self.conv3short8 = nn.Sequential(*conv3short8)
+         self.conv8_3 = nn.Sequential(*conv8_3)
+         self.conv9up = nn.Sequential(*conv9up)
+         self.conv9_2 = nn.Sequential(*conv9_2)
+         self.conv10up = nn.Sequential(*conv10up)
+         self.conv10_2 = nn.Sequential(*conv10_2)
+         # classification output
+         #self.model_class = nn.Sequential(*[nn.Conv2d(256, 313, kernel_size=1, padding=0, stride=1),])
+
+     def forward(self, input_grays):
+         f1_2 = self.conv1_2(input_grays)
+         f2_3 = self.conv2_3(f1_2)
+         f3_3 = self.conv3_3(f2_3)
+         f4_3 = self.conv4_3(f3_3)
+         f5_3 = self.conv5_3(f4_3)
+         f6_3 = self.conv6_3(f5_3)
+         f7_3 = self.conv7_3(f6_3)
+         f8_up = self.conv8up(f7_3) + self.conv3short8(f3_3)
+         f8_3 = self.conv8_3(f8_up)
+         f9_up = self.conv9up(f8_3)
+         f9_2 = self.conv9_2(f9_up)
+         f10_up = self.conv10up(f9_2)
+         f10_2 = self.conv10_2(f10_up)
+         out_feats = f10_2
+         #out_probs = self.model_class(f8_3)
+         return out_feats
+
+
+ def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
+     if batchNorm:
+         return nn.Sequential(
+             nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
+             nn.BatchNorm2d(out_planes),
+             nn.LeakyReLU(0.1)
+         )
+     else:
+         return nn.Sequential(
+             nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
+             nn.LeakyReLU(0.1)
+         )
+
+
+ def deconv(in_planes, out_planes):
+     return nn.Sequential(
+         nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
+         nn.LeakyReLU(0.1)
+     )
+
+
+ class SpixelNet(nn.Module):
+     def __init__(self, inChannel=3, outChannel=9, batchNorm=True):
+         super(SpixelNet,self).__init__()
+         self.batchNorm = batchNorm
+         self.conv0a = conv(self.batchNorm, inChannel, 16, kernel_size=3)
+         self.conv0b = conv(self.batchNorm, 16, 16, kernel_size=3)
+         self.conv1a = conv(self.batchNorm, 16, 32, kernel_size=3, stride=2)
+         self.conv1b = conv(self.batchNorm, 32, 32, kernel_size=3)
+         self.conv2a = conv(self.batchNorm, 32, 64, kernel_size=3, stride=2)
+         self.conv2b = conv(self.batchNorm, 64, 64, kernel_size=3)
+         self.conv3a = conv(self.batchNorm, 64, 128, kernel_size=3, stride=2)
+         self.conv3b = conv(self.batchNorm, 128, 128, kernel_size=3)
+         self.conv4a = conv(self.batchNorm, 128, 256, kernel_size=3, stride=2)
+         self.conv4b = conv(self.batchNorm, 256, 256, kernel_size=3)
+         self.deconv3 = deconv(256, 128)
+         self.conv3_1 = conv(self.batchNorm, 256, 128)
+         self.deconv2 = deconv(128, 64)
+         self.conv2_1 = conv(self.batchNorm, 128, 64)
+         self.deconv1 = deconv(64, 32)
+         self.conv1_1 = conv(self.batchNorm, 64, 32)
+         self.deconv0 = deconv(32, 16)
+         self.conv0_1 = conv(self.batchNorm, 32, 16)
+         self.pred_mask0 = nn.Conv2d(16, outChannel, kernel_size=3, stride=1, padding=1, bias=True)
+         self.softmax = nn.Softmax(1)
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+                 init.kaiming_normal_(m.weight, 0.1)
+                 if m.bias is not None:
+                     init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 init.constant_(m.weight, 1)
+                 init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         out1 = self.conv0b(self.conv0a(x))     #5*5
+         out2 = self.conv1b(self.conv1a(out1))  #11*11
+         out3 = self.conv2b(self.conv2a(out2))  #23*23
+         out4 = self.conv3b(self.conv3a(out3))  #47*47
+         out5 = self.conv4b(self.conv4a(out4))  #95*95
+         out_deconv3 = self.deconv3(out5)
+         concat3 = torch.cat((out4, out_deconv3), 1)
+         out_conv3_1 = self.conv3_1(concat3)
+         out_deconv2 = self.deconv2(out_conv3_1)
+         concat2 = torch.cat((out3, out_deconv2), 1)
+         out_conv2_1 = self.conv2_1(concat2)
+         out_deconv1 = self.deconv1(out_conv2_1)
+         concat1 = torch.cat((out2, out_deconv1), 1)
+         out_conv1_1 = self.conv1_1(concat1)
+         out_deconv0 = self.deconv0(out_conv1_1)
+         concat0 = torch.cat((out1, out_deconv0), 1)
+         out_conv0_1 = self.conv0_1(concat0)
+         mask0 = self.pred_mask0(out_conv0_1)
+         prob0 = self.softmax(mask0)
+         return prob0
+
+
+ ## VGG architecture, used for the perceptual loss with a pretrained VGG network
+ class VGG19(torch.nn.Module):
+     def __init__(self, requires_grad=False, local_pretrained_path='checkpoints/vgg19.pth'):
+         super().__init__()
+         #vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
+         model = torchvision.models.vgg19()
+         model.load_state_dict(torch.load(local_pretrained_path))
+         vgg_pretrained_features = model.features
+
+         self.slice1 = torch.nn.Sequential()
+         self.slice2 = torch.nn.Sequential()
+         self.slice3 = torch.nn.Sequential()
+         self.slice4 = torch.nn.Sequential()
+         self.slice5 = torch.nn.Sequential()
+         for x in range(2):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(2, 7):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(7, 12):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(12, 21):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(21, 30):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         h_relu1 = self.slice1(X)
+         h_relu2 = self.slice2(h_relu1)
+         h_relu3 = self.slice3(h_relu2)
+         h_relu4 = self.slice4(h_relu3)
+         h_relu5 = self.slice5(h_relu4)
+         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+         return out
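
Note: a quick shape check for SpixelNet (illustration only). The encoder halves the resolution four times and the decoder restores it, so the output matches the input size; the 9 softmax channels are consumed by `basic.poolfeat`/`basic.upfeat` elsewhere in this commit as per-pixel affinities — by the usual superpixel-FCN convention, to the 3x3 neighboring superpixel cells (an interpretation, not stated in this file):

    import torch
    from models.network import SpixelNet

    net = SpixelNet(inChannel=1, outChannel=9, batchNorm=True).eval()
    gray = torch.zeros(1, 1, 256, 256)
    with torch.no_grad():
        prob = net(gray)                   # (1, 9, 256, 256)
    print(prob.sum(dim=1).mean().item())   # ~1.0: softmax over the 9 channels
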
models/position_encoding.py ADDED
@@ -0,0 +1,86 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Various positional encodings for the transformer.
+ """
+ import math
+ import torch
+ from torch import nn
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, token_tensors):
+         ## input: (B,C,H,W)
+         x = token_tensors
+         h, w = x.shape[-2:]
+         identity_map = torch.ones((h, w), device=x.device)
+         y_embed = identity_map.cumsum(0, dtype=torch.float32)
+         x_embed = identity_map.cumsum(1, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, None] / dim_t
+         pos_y = y_embed[:, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+         pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+         pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
+         batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return batch_pos
+
+
+ class PositionEmbeddingLearned(nn.Module):
+     """
+     Absolute pos embedding, learned.
+     """
+     def __init__(self, n_pos_x=16, n_pos_y=16, num_pos_feats=64):
+         super().__init__()
+         self.row_embed = nn.Embedding(n_pos_y, num_pos_feats)
+         self.col_embed = nn.Embedding(n_pos_x, num_pos_feats)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.row_embed.weight)
+         nn.init.uniform_(self.col_embed.weight)
+
+     def forward(self, token_tensors):
+         ## input: (B,C,H,W)
+         x = token_tensors
+         h, w = x.shape[-2:]
+         i = torch.arange(w, device=x.device)
+         j = torch.arange(h, device=x.device)
+         x_emb = self.col_embed(i)
+         y_emb = self.row_embed(j)
+         pos = torch.cat([
+             x_emb.unsqueeze(0).repeat(h, 1, 1),
+             y_emb.unsqueeze(1).repeat(1, w, 1),
+         ], dim=-1).permute(2, 0, 1)
+         batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return batch_pos
+
+
+ def build_position_encoding(num_pos_feats=64, n_pos_x=16, n_pos_y=16, is_learned=False):
+     if is_learned:
+         position_embedding = PositionEmbeddingLearned(n_pos_x, n_pos_y, num_pos_feats)
+     else:
+         position_embedding = PositionEmbeddingSine(num_pos_feats, normalize=True)
+
+     return position_embedding
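
Note: the sine variant concatenates a y-encoding and an x-encoding of `num_pos_feats` channels each, so the returned map has 2*num_pos_feats channels — which is why model.py above builds it with `d_model//2`. A small sketch (illustration only):

    import torch
    from models.position_encoding import build_position_encoding

    pos_enc = build_position_encoding(num_pos_feats=32, is_learned=False)
    tokens = torch.zeros(2, 64, 16, 16)    # (B, C, H, W) token grid
    pos = pos_enc(tokens)
    print(pos.shape)                       # torch.Size([2, 64, 16, 16])
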
models/transformer2d.py ADDED
@@ -0,0 +1,229 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import copy, math
+ from models.position_encoding import build_position_encoding
+
+
+ class TransformerEncoder(nn.Module):
+
+     def __init__(self, enc_layer, num_layers, use_dense_pos=False):
+         super().__init__()
+         self.layers = nn.ModuleList([copy.deepcopy(enc_layer) for i in range(num_layers)])
+         self.num_layers = num_layers
+         self.use_dense_pos = use_dense_pos
+
+     def forward(self, src, pos, padding_mask=None):
+         if self.use_dense_pos:
+             ## pos encoding at each MH-Attention block (q,k)
+             output, pos_enc = src, pos
+             for layer in self.layers:
+                 output, att_map = layer(output, pos_enc, padding_mask)
+         else:
+             ## pos encoding at input only (q,k,v)
+             output, pos_enc = src + pos, None
+             for layer in self.layers:
+                 output, att_map = layer(output, pos_enc, padding_mask)
+         return output, att_map
+
+
+ class EncoderLayer(nn.Module):
+
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
+                  use_dense_pos=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+
+     def with_pos_embed(self, tensor, pos):
+         return tensor if pos is None else tensor + pos
+
+     def forward(self, src, pos, padding_mask):
+         q = k = self.with_pos_embed(src, pos)
+         src2, attn = self.self_attn(q, k, value=src, key_padding_mask=padding_mask)
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+         return src, attn
+
+
+ class TransformerDecoder(nn.Module):
+
+     def __init__(self, dec_layer, num_layers, use_dense_pos=False, return_intermediate=False):
+         super().__init__()
+         self.layers = nn.ModuleList([copy.deepcopy(dec_layer) for i in range(num_layers)])
+         self.num_layers = num_layers
+         self.use_dense_pos = use_dense_pos
+         self.return_intermediate = return_intermediate
+
+     def forward(self, tgt, tgt_pos, memory, memory_pos,
+                 tgt_padding_mask, src_padding_mask, tgt_attn_mask=None):
+         intermediate = []
+         if self.use_dense_pos:
+             ## pos encoding at each MH-Attention block (q,k)
+             output = tgt
+             tgt_pos_enc, memory_pos_enc = tgt_pos, memory_pos
+             for layer in self.layers:
+                 output, att_map = layer(output, tgt_pos_enc, memory, memory_pos_enc,
+                                         tgt_padding_mask, src_padding_mask, tgt_attn_mask)
+                 if self.return_intermediate:
+                     intermediate.append(output)
+         else:
+             ## pos encoding at input only (q,k,v)
+             output = tgt + tgt_pos
+             tgt_pos_enc, memory_pos_enc = None, None
+             for layer in self.layers:
+                 output, att_map = layer(output, tgt_pos_enc, memory, memory_pos_enc,
+                                         tgt_padding_mask, src_padding_mask, tgt_attn_mask)
+                 if self.return_intermediate:
+                     intermediate.append(output)
+
+         if self.return_intermediate:
+             return torch.stack(intermediate)
+         return output, att_map
+
+
+ class DecoderLayer(nn.Module):
+
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
+                  use_dense_pos=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.corr_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+
+     def with_pos_embed(self, tensor, pos):
+         return tensor if pos is None else tensor + pos
+
+     def forward(self, tgt, tgt_pos, memory, memory_pos,
+                 tgt_padding_mask, memory_padding_mask, tgt_attn_mask):
+         q = k = self.with_pos_embed(tgt, tgt_pos)
+         tgt2, attn = self.self_attn(q, k, value=tgt, key_padding_mask=tgt_padding_mask,
+                                     attn_mask=tgt_attn_mask)
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+         tgt2, attn = self.corr_attn(query=self.with_pos_embed(tgt, tgt_pos),
+                                     key=self.with_pos_embed(memory, memory_pos),
+                                     value=memory, key_padding_mask=memory_padding_mask)
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt, attn
+
+
+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
+
+
+ #-----------------------------------------------------------------------------------
+ '''
+ copied from the implementation of "attention-is-all-you-need-pytorch-master" by Yu-Hsiang Huang
+ '''
+
+ class MultiHeadAttention(nn.Module):
+     ''' Multi-Head Attention module '''
+
+     def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+         super().__init__()
+
+         self.n_head = n_head
+         self.d_k = d_k
+         self.d_v = d_v
+
+         self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
+         self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
+         self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
+         self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
+
+         self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
+
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+     def forward(self, q, k, v, mask=None):
+
+         d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+         sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+         residual = q
+
+         # Pass through the pre-attention projection: b x lq x (n*dv)
+         # Separate different heads: b x lq x n x dv
+         q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+         k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+         v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+
+         # Transpose for attention dot product: b x n x lq x dv
+         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+         if mask is not None:
+             mask = mask.unsqueeze(1)   # For head axis broadcasting.
+
+         q, attn = self.attention(q, k, v, mask=mask)
+
+         # Transpose to move the head dimension back: b x lq x n x dv
+         # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
+         q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+         q = self.dropout(self.fc(q))
+         q += residual
+
+         q = self.layer_norm(q)
+
+         return q, attn
+
+
+ class ScaledDotProductAttention(nn.Module):
+     ''' Scaled Dot-Product Attention '''
+
+     def __init__(self, temperature, attn_dropout=0.1):
+         super().__init__()
+         self.temperature = temperature
+         self.dropout = nn.Dropout(attn_dropout)
+
+     def forward(self, q, k, v, mask=None):
+
+         attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
+
+         if mask is not None:
+             attn = attn.masked_fill(mask == 0, -1e9)
+
+         attn = self.dropout(F.softmax(attn, dim=-1))
+         output = torch.matmul(attn, v)
+
+         return output, attn
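
Note: both encoder paths in model.py consume sequences in (HW, N, C) layout, matching nn.MultiheadAttention's default. A minimal sketch of the encoder stack (illustration only):

    import torch
    from models.transformer2d import EncoderLayer, TransformerEncoder

    d_model = 64
    layer = EncoderLayer(d_model, nhead=8, dim_feedforward=4*d_model, use_dense_pos=True)
    encoder = TransformerEncoder(layer, num_layers=2, use_dense_pos=True)
    src = torch.zeros(256, 1, d_model)     # a flattened 16x16 token grid
    pos = torch.zeros(256, 1, d_model)
    out, attn = encoder(src, pos, padding_mask=None)
    print(out.shape)                       # torch.Size([256, 1, 64])
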
predict.py ADDED
@@ -0,0 +1,104 @@
+ # Prediction interface for Cog ⚙️
+ # https://github.com/replicate/cog/blob/main/docs/python.md
+
+ from cog import BasePredictor, Input, Path
+ import tempfile
+ import os, glob
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models import model, basic
+ from utils import util
+
+ class Predictor(BasePredictor):
+     def setup(self):
+         """Load the model into memory to make running multiple predictions efficient"""
+         seed = 130
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+         #print('--------------', torch.cuda.is_available())
+         self.color_class = basic.ColorLabel(lambda_=0.5, device='cuda')
+         ## pass the color labeler required by AnchorColorProb.forward (see models/model.py)
+         self.colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True,
+                                                colorLabeler=self.color_class)
+         self.colorizer = self.colorizer.cuda()
+         checkpt_path = "./checkpoints/disco-beta.pth.rar"
+         assert os.path.exists(checkpt_path)
+         data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
+         self.colorizer.load_state_dict(data_dict['state_dict'])
+         self.colorizer.eval()
+
+     def resize_ab2l(self, gray_img, lab_imgs):
+         H, W = gray_img.shape[:2]
+         resized_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
+         return np.concatenate((gray_img, resized_ab), axis=2)
+
+     def predict(
+         self,
+         image: Path = Input(description="input image. Output will be one or multiple colorized images."),
+         n_anchors: int = Input(
+             description="number of color anchors", ge=3, le=14, default=8
+         ),
+         multi_result: bool = Input(
+             description="to generate diverse results", default=False
+         ),
+         vis_anchors: bool = Input(
+             description="to visualize the anchor locations", default=False
+         )
+     ) -> Path:
+         """Run a single prediction on the model"""
+         bgr_img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+         rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
+         rgb_img = np.array(rgb_img / 255., np.float32)
+         lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
+         org_grays = (lab_img[:,:,[0]]-50.) / 50.
+         lab_img = cv2.resize(lab_img, (256,256), interpolation=cv2.INTER_LINEAR)
+
+         lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
+         gray_img = (lab_img[0:1,:,:]-50.) / 50.
+         ab_chans = lab_img[1:3,:,:] / 110.
+         input_grays = gray_img.unsqueeze(0)
+         input_colors = ab_chans.unsqueeze(0)
+         input_grays = input_grays.cuda(non_blocking=True)
+         input_colors = input_colors.cuda(non_blocking=True)
+
+         sampled_T = 2 if multi_result else 0
+         ## call matches AnchorColorProb.forward(input_grays, input_colors, n_anchors, sampled_T)
+         pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = self.colorizer(input_grays, \
+             input_colors, n_anchors, sampled_T)
+         pred_probs = pal_logit
+         guided_colors = self.color_class.decode_ind2ab(ref_logit, T=0)
+         sp_size = 16
+         guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
+         res_list = []
+         if multi_result:
+             for no in range(3):
+                 pred_labs = torch.cat((input_grays,enhanced_ab[no:no+1,:,:,:]), dim=1)
+                 lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
+                 lab_imgs = self.resize_ab2l(org_grays, lab_imgs)
+                 #util.save_normLabs_from_batch(lab_imgs, save_dir, [file_name], -1, suffix='c%d'%no)
+                 res_list.append(lab_imgs)
+         else:
+             pred_labs = torch.cat((input_grays,enhanced_ab), dim=1)
+             lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
+             lab_imgs = self.resize_ab2l(org_grays, lab_imgs)
+             #util.save_normLabs_from_batch(lab_imgs, save_dir, [file_name], -1)#, suffix='enhanced')
+             res_list.append(lab_imgs)
+
+         if vis_anchors:
+             ## visualize anchor locations
+             anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
+             marked_labs = basic.mark_color_hints(input_grays, enhanced_ab, anchor_masks, base_ABs=enhanced_ab)
+             hint_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
+             hint_imgs = self.resize_ab2l(org_grays, hint_imgs)
+             #util.save_normLabs_from_batch(hint_imgs, save_dir, [file_name], -1, suffix='anchors')
+             res_list.append(hint_imgs)
+
+         output = cv2.vconcat(res_list)
+         output[:,:,0] = output[:,:,0] * 50.0 + 50.0
+         output[:,:,1:3] = output[:,:,1:3] * 110.0
+         bgr_output = cv2.cvtColor(output[:,:,:], cv2.COLOR_LAB2BGR)
+         out_path = Path(tempfile.mkdtemp()) / "out.png"
+         cv2.imwrite(str(out_path), (bgr_output*255.0).astype(np.uint8))
+         return out_path
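
Note: a direct-invocation sketch for local testing (illustration only — this bypasses the Cog runtime, which normally calls setup()/predict() itself; a CUDA device and the checkpoint under ./checkpoints are required, and input.jpg is a hypothetical test image):

    from pathlib import Path
    from predict import Predictor

    p = Predictor()
    p.setup()                                    # loads the checkpoint onto CUDA
    out = p.predict(image=Path('input.jpg'), n_anchors=8,
                    multi_result=False, vis_anchors=False)
    print(out)                                   # path to the colorized PNG
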
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ addict
+ future
+ numpy
+ opencv-python
+ pandas
+ Pillow
+ pyyaml
+ requests
+ scikit-image
+ scikit-learn
+ scipy
+ torch>=1.8.0
+ torchvision
+ tensorboardx>=2.4
+ tqdm
+ yapf
+ lpips
utils/__init__.py ADDED
File without changes
utils/cielab.py ADDED
@@ -0,0 +1,71 @@
+ from functools import partial
+ import numpy as np
+
+ class ABGamut:
+     RESOURCE_POINTS = "./utils/gamut_pts.npy"
+     RESOURCE_PRIOR = "./utils/gamut_probs.npy"
+     DTYPE = np.float32
+     EXPECTED_SIZE = 313
+
+     def __init__(self):
+         self.points = np.load(self.RESOURCE_POINTS).astype(self.DTYPE)
+         self.prior = np.load(self.RESOURCE_PRIOR).astype(self.DTYPE)
+         assert self.points.shape == (self.EXPECTED_SIZE, 2)
+         assert self.prior.shape == (self.EXPECTED_SIZE,)
+
+
+ class CIELAB:
+     L_MEAN = 50
+     AB_BINSIZE = 10
+     AB_RANGE = [-110 - AB_BINSIZE // 2, 110 + AB_BINSIZE // 2, AB_BINSIZE]
+     AB_DTYPE = np.float32
+     Q_DTYPE = np.int64
+
+     RGB_RESOLUTION = 101
+     RGB_RANGE = [0, 1, RGB_RESOLUTION]
+     RGB_DTYPE = np.float64
+
+     def __init__(self, gamut=None):
+         self.gamut = gamut if gamut is not None else ABGamut()
+         a, b, self.ab = self._get_ab()
+         self.ab_gamut_mask = self._get_ab_gamut_mask(
+             a, b, self.ab, self.gamut)
+
+         self.ab_to_q = self._get_ab_to_q(self.ab_gamut_mask)
+         self.q_to_ab = self._get_q_to_ab(self.ab, self.ab_gamut_mask)
+
+     @classmethod
+     def _get_ab(cls):
+         a = np.arange(*cls.AB_RANGE, dtype=cls.AB_DTYPE)
+         b = np.arange(*cls.AB_RANGE, dtype=cls.AB_DTYPE)
+         b_, a_ = np.meshgrid(a, b)
+         ab = np.dstack((a_, b_))
+         return a, b, ab
+
+     @classmethod
+     def _get_ab_gamut_mask(cls, a, b, ab, gamut):
+         ab_gamut_mask = np.full(ab.shape[:-1], False, dtype=bool)
+         a = np.digitize(gamut.points[:, 0], a) - 1
+         b = np.digitize(gamut.points[:, 1], b) - 1
+         for a_, b_ in zip(a, b):
+             ab_gamut_mask[a_, b_] = True
+
+         return ab_gamut_mask
+
+     @classmethod
+     def _get_ab_to_q(cls, ab_gamut_mask):
+         ab_to_q = np.full(ab_gamut_mask.shape, -1, dtype=cls.Q_DTYPE)
+         ab_to_q[ab_gamut_mask] = np.arange(np.count_nonzero(ab_gamut_mask))
+
+         return ab_to_q
+
+     @classmethod
+     def _get_q_to_ab(cls, ab, ab_gamut_mask):
+         return ab[ab_gamut_mask] + cls.AB_BINSIZE / 2
+
+     def bin_ab(self, ab):
+         ab_discrete = ((ab + 110) / self.AB_RANGE[2]).astype(int)
+
+         a, b = np.hsplit(ab_discrete.reshape(-1, 2), 2)
+
+         return self.ab_to_q[a, b].reshape(*ab.shape[:2])
+
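
Note: a small worked example of the ab-space quantization above (illustration only; the two .npy gamut files below must be present). Bins are 10 units wide, and ab pairs falling outside the 313-bin gamut map to -1:

    import numpy as np
    from utils.cielab import CIELAB

    cielab = CIELAB()
    ab = np.array([[[25.0, -40.0]]], dtype=np.float32)   # one (a,b) pixel, (H,W,2)
    q = cielab.bin_ab(ab)
    print(q.shape, int(q[0, 0]))                         # (1, 1) -> bin id in [0, 313)
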
utils/dataset_lab.py ADDED
@@ -0,0 +1,37 @@
+ from __future__ import print_function, division
+ import torch, os, glob
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ from PIL import Image
+ import cv2
+
+
+ class LabDataset(Dataset):
+
+     def __init__(self, rootdir=None, filelist=None, resize=None):
+         if filelist:
+             self.file_list = filelist
+         else:
+             assert os.path.exists(rootdir), "@dir:'%s' does NOT exist ..." % rootdir
+             self.file_list = glob.glob(os.path.join(rootdir, '*.*'))
+             self.file_list.sort()
+         self.resize = resize
+
+     def __len__(self):
+         return len(self.file_list)
+
+     def __getitem__(self, idx):
+         bgr_img = cv2.imread(self.file_list[idx], cv2.IMREAD_COLOR)
+         if self.resize:
+             bgr_img = cv2.resize(bgr_img, (self.resize, self.resize), interpolation=cv2.INTER_CUBIC)
+         bgr_img = np.array(bgr_img / 255., np.float32)
+         lab_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2LAB)
+         # float32 Lab from OpenCV: L in [0,100], ab roughly in [-110,110]
+         lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
+         bgr_img = torch.from_numpy(bgr_img.transpose((2, 0, 1)))
+         gray_img = (lab_img[0:1, :, :] - 50.) / 50.   # L channel -> [-1,1]
+         color_map = lab_img[1:3, :, :] / 110.         # ab channels -> ~[-1,1]
+         bgr_img = bgr_img * 2. - 1.                   # BGR -> [-1,1]
+         return {'gray': gray_img, 'color': color_map, 'BGR': bgr_img}
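A minimal loading sketch for this dataset (the image directory './examples' is a placeholder):

    from torch.utils.data import DataLoader
    from utils.dataset_lab import LabDataset

    dataset = LabDataset(rootdir='./examples', resize=256)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch['gray'].shape, batch['color'].shape)  # [4,1,256,256], [4,2,256,256]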
utils/gamut_probs.npy ADDED
Binary file (2.58 kB).
 
utils/gamut_pts.npy ADDED
Binary file (5.09 kB).
 
utils/util.py ADDED
@@ -0,0 +1,178 @@
+ from __future__ import division
+ from __future__ import print_function
+ import os, glob, shutil, math, json
+ from queue import Queue
+ from threading import Thread
+ from skimage.segmentation import mark_boundaries
+ import matplotlib.pyplot as plt   # needed by batchGray2Colormap
+ import numpy as np
+ from PIL import Image
+ import cv2, torch
+
+ def get_gauss_kernel(size, sigma):
+     '''Function to mimic the 'fspecial' gaussian MATLAB function'''
+     x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
+     g = np.exp(-((x**2 + y**2) / (2.0*sigma**2)))
+     return g / g.sum()
+
+
+ def batchGray2Colormap(gray_batch):
+     colormap = plt.get_cmap('viridis')
+     heatmap_batch = []
+     for i in range(gray_batch.shape[0]):
+         # map each gray map through the colormap, then rescale to [-1,1]
+         gray_map = gray_batch[i, :, :, 0]
+         heatmap = (colormap(gray_map)[:, :, :3] * 255.).astype(np.uint8)
+         heatmap_batch.append(heatmap / 127.5 - 1.0)
+     return np.array(heatmap_batch)
+
+
+ class PlotterThread():
+     '''log tensorboard data in a background thread to save time'''
+     def __init__(self, writer):
+         self.writer = writer
+         self.task_queue = Queue(maxsize=0)
+         worker = Thread(target=self.do_work, args=(self.task_queue,))
+         worker.daemon = True   # setDaemon() is deprecated
+         worker.start()
+
+     def do_work(self, q):
+         while True:
+             content = q.get()
+             if content[-1] == 'image':
+                 self.writer.add_image(*content[:-1])
+             elif content[-1] == 'scalar':
+                 self.writer.add_scalar(*content[:-1])
+             else:
+                 raise ValueError('unsupported data type: %s' % content[-1])
+             q.task_done()
+
+     def add_data(self, name, value, step, data_type='scalar'):
+         self.task_queue.put([name, value, step, data_type])
+
+     def __len__(self):
+         return self.task_queue.qsize()
+
+
+ def save_images_from_batch(img_batch, save_dir, filename_list, batch_no=-1, suffix=None):
+     N, H, W, C = img_batch.shape
+     if C == 3:
+         #! rgb color image
+         for i in range(N):
+             # [-1,1] >>> [0,255]
+             image = Image.fromarray((127.5*(img_batch[i,:,:,:]+1.)).astype(np.uint8))
+             save_name = filename_list[i] if batch_no == -1 else '%05d.png' % (batch_no*N+i)
+             save_name = save_name.replace('.png', '-%s.png' % suffix) if suffix else save_name
+             image.save(os.path.join(save_dir, save_name), 'PNG')
+     elif C == 1:
+         #! single-channel gray image
+         for i in range(N):
+             # [-1,1] >>> [0,255]
+             image = Image.fromarray((127.5*(img_batch[i,:,:,0]+1.)).astype(np.uint8))
+             save_name = filename_list[i] if batch_no == -1 else '%05d.png' % (batch_no*N+i)
+             save_name = save_name.replace('.png', '-%s.png' % suffix) if suffix else save_name
+             image.save(os.path.join(save_dir, save_name), 'PNG')
+     else:
+         #! multi-channel: save each channel as a single image
+         for i in range(N):
+             # [-1,1] >>> [0,255]
+             for j in range(C):
+                 image = Image.fromarray((127.5*(img_batch[i,:,:,j]+1.)).astype(np.uint8))
+                 if batch_no == -1:
+                     _, file_name = os.path.split(filename_list[i])
+                     name_only, _ = os.path.splitext(file_name)
+                     save_name = name_only + '_c%d.png' % j
+                 else:
+                     save_name = '%05d_c%d.png' % (batch_no*N+i, j)
+                 save_name = save_name.replace('.png', '-%s.png' % suffix) if suffix else save_name
+                 image.save(os.path.join(save_dir, save_name), 'PNG')
+     return None
+
+
+ def save_normLabs_from_batch(img_batch, save_dir, filename_list, batch_no=-1, suffix=None):
+     N, H, W, C = img_batch.shape
+     if C != 3:
+         print('@Warning: the Lab images are NOT in 3 channels!')
+         return None
+     # denormalization: L: (L+1.0)*50.0 | a: a*110.0 | b: b*110.0
+     img_batch[:,:,:,0] = img_batch[:,:,:,0] * 50.0 + 50.0
+     img_batch[:,:,:,1:3] = img_batch[:,:,:,1:3] * 110.0
+     #! convert into RGB color image
+     for i in range(N):
+         rgb_img = cv2.cvtColor(img_batch[i,:,:,:], cv2.COLOR_LAB2RGB)
+         image = Image.fromarray((rgb_img*255.0).astype(np.uint8))
+         save_name = filename_list[i] if batch_no == -1 else '%05d.png' % (batch_no*N+i)
+         save_name = save_name.replace('.png', '-%s.png' % suffix) if suffix else save_name
+         image.save(os.path.join(save_dir, save_name), 'PNG')
+     return None
+
+
+ def save_markedSP_from_batch(img_batch, spix_batch, save_dir, filename_list, batch_no=-1, suffix=None):
+     N, H, W, C = img_batch.shape
+     #! img_batch: BGR nd-array (range:0~1)
+     #! spix_batch: single-channel superpixel map
+     for i in range(N):
+         norm_image = img_batch[i,:,:,:]*0.5 + 0.5
+         spixel_bd_image = mark_boundaries(norm_image, spix_batch[i,:,:,0].astype(int), color=(1,1,1))
+         image = Image.fromarray((spixel_bd_image*255.0).astype(np.uint8))
+         save_name = filename_list[i] if batch_no == -1 else '%05d.png' % (batch_no*N+i)
+         save_name = save_name.replace('.png', '-%s.png' % suffix) if suffix else save_name
+         image.save(os.path.join(save_dir, save_name), 'PNG')
+     return None
+
+
+ def get_filelist(data_dir):
+     file_list = glob.glob(os.path.join(data_dir, '*.*'))
+     file_list.sort()
+     return file_list
+
+
+ def collect_filenames(data_dir):
+     file_list = get_filelist(data_dir)
+     name_list = []
+     for file_path in file_list:
+         _, file_name = os.path.split(file_path)
+         name_list.append(file_name)
+     name_list.sort()
+     return name_list
+
+
+ def exists_or_mkdir(path, need_remove=False):
+     if not os.path.exists(path):
+         os.makedirs(path)
+     elif need_remove:
+         shutil.rmtree(path)
+         os.makedirs(path)
+     return None
+
+
+ def save_list(save_path, data_list, append_mode=False):
+     n = len(data_list)
+     if append_mode:
+         # append only the latest entry
+         with open(save_path, 'a') as f:
+             f.writelines([str(data_list[i]) + '\n' for i in range(n-1, n)])
+     else:
+         with open(save_path, 'w') as f:
+             f.writelines([str(data_list[i]) + '\n' for i in range(n)])
+     return None
+
+
+ def save_dict(save_path, data_dict):
+     # json.dump() writes to a file object (json.dumps() only returns a string)
+     with open(save_path, 'w') as f:
+         json.dump(data_dict, f)
+     return None
+
+
+ if __name__ == '__main__':
+     # quick visual check of the quantized ab gamut (uses CIELAB from cielab.py)
+     from cielab import CIELAB
+     clbar = CIELAB()
+     ab2q = clbar.ab_to_q
+     q2ab = clbar.q_to_ab
+     maps = clbar.ab_gamut_mask * 255.0
+     image = Image.fromarray(maps.astype(np.uint8))
+     image.save('gamut.png', 'PNG')
+     print(ab2q.shape)
+     print(q2ab.shape)
+     print('label range:', np.min(ab2q), np.max(ab2q))
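Typical use of PlotterThread with a tensorboardX writer (a hedged sketch; the log directory is illustrative):

    from tensorboardX import SummaryWriter
    from utils.util import PlotterThread

    writer = SummaryWriter('./logs')   # './logs' is a placeholder
    plotter = PlotterThread(writer)
    for step in range(10):
        plotter.add_data('train/loss', 1.0 / (step + 1), step)   # queued, written off-thread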