menghanxia commited on
Commit
302d824
1 Parent(s): b3640b9

modified app.py with checkpt downloading

Browse files
Files changed (1) hide show
  1. app.py +17 -41
app.py CHANGED
@@ -2,40 +2,12 @@ import gradio as gr
2
  import os, requests
3
  from inference import setup_model, colorize_grayscale, predict_anchors
4
 
5
- ## download checkpoint
6
- def download_file_from_google_drive(id, destination):
7
- def get_confirm_token(response):
8
- for key, value in response.cookies.items():
9
- if key.startswith('download_warning'):
10
- return value
11
- return None
12
-
13
- def save_response_content(response, destination):
14
- CHUNK_SIZE = 32768
15
- with open(destination, "wb") as f:
16
- for chunk in response.iter_content(CHUNK_SIZE):
17
- if chunk: # filter out keep-alive new chunks
18
- f.write(chunk)
19
-
20
- URL = "https://docs.google.com/uc?export=download"
21
- session = requests.Session()
22
- response = session.get(URL, params = { 'id' : id }, stream = True)
23
- token = get_confirm_token(response)
24
-
25
- if token:
26
- params = { 'id' : id, 'confirm' : token }
27
- response = session.get(URL, params = params, stream = True)
28
- save_response_content(response, destination)
29
-
30
- id = "1J4vB6kG4xBLUUKpXr5IhnSSa4maXgRvQ"
31
- destination = "disco-beta.pth.rar"
32
- download_file_from_google_drive(id, destination)
33
  os.rename("disco-beta.pth.tar", "./checkpoints/disco-beta.pth.tar")
34
 
35
  ## step 1: set up model
36
- device = "cuda"
37
- checkpt_path = "./checkpoints/disco-beta.pth.tar"
38
- assert os.path.exists(checkpt_path), "No checkpoint found!"
39
  colorizer, colorLabeler = setup_model(checkpt_path, device=device)
40
 
41
  def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
@@ -55,9 +27,11 @@ def switch_states(is_checked):
55
  else:
56
  return gr.Image.update(visible=False), gr.Button.update(visible=False)
57
 
58
- demo = gr.Blocks(title="DISCO: Image Colorization")
59
  with demo:
60
- gr.Markdown(value="""**DISCO: image colorization that disentangles color multimodality and spatial affinity via global anchors**.""")
 
 
61
  with gr.Row():
62
  with gr.Column(scale=1):
63
  Image_input = gr.Image(type="numpy", label="Input", interactive=True)
@@ -78,15 +52,17 @@ with demo:
78
  Button_run.click(fn=click_colorize, inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Ckeckbox_editable], \
79
  outputs=Image_output)
80
  ## guiline
81
- gr.Markdown(value="""
82
- **Guideline**
83
- 1. Upload your image;
84
  2. Set up the arguments: "Num. of anchors" and "Colorization resolution";
85
- 3. Two modes are supported:
86
- - **Editable**: check ""Show editable anchors" and click "Predict anchors". Then, modify the colors of the predicted anchors (anchor mask will be applied afterward). Finally, click "Colorize" to get the result.
87
  - **Automatic**: click "Colorize" to get the automatically colorized output.
88
-
89
- *To know more about the method, please refer to our project page: [https://menghanxia.github.io/projects/disco.html](https://menghanxia.github.io/projects/disco.html)*
 
 
90
  """)
91
 
92
- demo.launch(server_name='9.134.253.83',server_port=7788)
 
 
2
  import os, requests
3
  from inference import setup_model, colorize_grayscale, predict_anchors
4
 
5
+ os.system("wget https://huggingface.co/menghanxia/disco/tree/main/disco-beta.pth.tar")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  os.rename("disco-beta.pth.tar", "./checkpoints/disco-beta.pth.tar")
7
 
8
  ## step 1: set up model
9
+ device = "cpu"
10
+ checkpt_path = "checkpoints/disco-beta.pth.rar"
 
11
  colorizer, colorLabeler = setup_model(checkpt_path, device=device)
12
 
13
  def click_colorize(rgb_img, hint_img, n_anchors, is_high_res, is_editable):
 
27
  else:
28
  return gr.Image.update(visible=False), gr.Button.update(visible=False)
29
 
30
+ demo = gr.Blocks(title="DISCO")
31
  with demo:
32
+ gr.Markdown(value="""
33
+ **Gradio demo for DISCO: Disentangled Image Colorization via Global Anchors. [Project Page](https://menghanxia.github.io/projects/disco.html)**.
34
+ """)
35
  with gr.Row():
36
  with gr.Column(scale=1):
37
  Image_input = gr.Image(type="numpy", label="Input", interactive=True)
 
52
  Button_run.click(fn=click_colorize, inputs=[Image_input, Image_anchor, Num_anchor, Radio_resolution, Ckeckbox_editable], \
53
  outputs=Image_output)
54
  ## guiline
55
+ gr.Markdown(value="""
56
+ **Usage Guideline**
57
+ 1. upload your image;
58
  2. Set up the arguments: "Num. of anchors" and "Colorization resolution";
59
+ 3. Run the colorization (two modes supported):
 
60
  - **Automatic**: click "Colorize" to get the automatically colorized output.
61
+ - **Editable**: check ""Show editable anchors" and click "Predict anchors". Then, modify the colors of the predicted anchors (only anchor region will be used). Finally, click "Colorize" to get the result.
62
+ """)
63
+ gr.HTML(value="""
64
+ <p style='text-align: center'><a href='https://menghanxia.github.io/projects/disco.html' target='_blank'>DISCO Project Page</a> | <a href='https://github.com/MenghanXia/DisentangledColorization' target='_blank'>Github Repo</a></p>
65
  """)
66
 
67
+ #demo.launch(server_name='9.134.253.83',server_port=7788)
68
+ demo.launch()