andreped commited on
Commit
4146752
1 Parent(s): 366c806

Added support to download prediction

Browse files
Files changed (2) hide show
  1. demo/src/gui.py +40 -15
  2. demo/src/inference.py +5 -1
demo/src/gui.py CHANGED
@@ -59,7 +59,8 @@ class WebUI:
59
  visible=True,
60
  elem_id="model-3d",
61
  camera_position=[90, 180, 768],
62
- ).style(height=512)
 
63
 
64
  def set_class_name(self, value):
65
  LOGGER.info(f"Changed task to: {value}")
@@ -75,22 +76,31 @@ class WebUI:
75
 
76
  def process(self, mesh_file_name):
77
  path = mesh_file_name.name
 
 
 
78
  run_model(
79
  path,
80
  model_path=os.path.join(self.cwd, "resources/models/"),
81
  task=self.class_names[self.class_name],
82
  name=self.result_names[self.class_name],
 
83
  )
84
  LOGGER.info("Converting prediction NIfTI to OBJ...")
85
- nifti_to_obj("prediction.nii.gz")
86
 
87
  LOGGER.info("Loading CT to numpy...")
88
  self.images = load_ct_to_numpy(path)
89
 
90
  LOGGER.info("Loading prediction volume to numpy..")
91
- self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")
92
 
93
  return "./prediction.obj"
 
 
 
 
 
94
 
95
  def get_img_pred_pair(self, k):
96
  k = int(k)
@@ -98,7 +108,6 @@ class WebUI:
98
  self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
99
  visible=True,
100
  elem_id="model-2d",
101
- ).style(
102
  color_map={self.class_name: "#ffae00"},
103
  height=512,
104
  width=512,
@@ -122,9 +131,7 @@ class WebUI:
122
  autoscroll=True,
123
  elem_id="logs",
124
  show_copy_button=True,
125
- scroll_to_output=False,
126
  container=True,
127
- line_breaks=True,
128
  )
129
  demo.load(read_logs, None, logs, every=1)
130
 
@@ -160,7 +167,7 @@ class WebUI:
160
  label="Task",
161
  info="Which structure to segment.",
162
  multiselect=False,
163
- size="sm",
164
  )
165
  model_selector.input(
166
  fn=lambda x: self.set_class_name(x),
@@ -173,16 +180,33 @@ class WebUI:
173
  "Run analysis",
174
  variant="primary",
175
  elem_id="run-button",
176
- ).style(
177
- full_width=False,
178
  size="lg",
179
  )
 
 
 
 
180
  run_btn.click(
181
  fn=lambda x: self.process(x),
182
  inputs=file_output,
183
  outputs=self.volume_renderer,
184
  )
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  with gr.Row():
187
  gr.Examples(
188
  examples=[
@@ -202,16 +226,17 @@ class WebUI:
202
  )
203
 
204
  with gr.Row():
205
- with gr.Box():
206
  with gr.Column():
207
  # create dummy image to be replaced by loaded images
208
  t = gr.AnnotatedImage(
209
  visible=True, elem_id="model-2d"
210
- ).style(
211
- color_map={self.class_name: "#ffae00"},
212
- height=512,
213
- width=512,
214
  )
 
 
 
 
 
215
 
216
  self.slider.input(
217
  self.get_img_pred_pair,
@@ -221,7 +246,7 @@ class WebUI:
221
 
222
  self.slider.render()
223
 
224
- with gr.Box():
225
  self.volume_renderer.render()
226
 
227
  # sharing app publicly -> share=True:
 
59
  visible=True,
60
  elem_id="model-3d",
61
  camera_position=[90, 180, 768],
62
+ height=512,
63
+ )
64
 
65
  def set_class_name(self, value):
66
  LOGGER.info(f"Changed task to: {value}")
 
76
 
77
  def process(self, mesh_file_name):
78
  path = mesh_file_name.name
79
+ curr = path.split("/")[-1]
80
+ self.extension = ".".join(curr.split(".")[1:])
81
+ self.filename = curr.split(".")[0] + "-" + self.class_names[self.class_name]
82
  run_model(
83
  path,
84
  model_path=os.path.join(self.cwd, "resources/models/"),
85
  task=self.class_names[self.class_name],
86
  name=self.result_names[self.class_name],
87
+ output_filename=self.filename + "." + self.extension
88
  )
89
  LOGGER.info("Converting prediction NIfTI to OBJ...")
90
+ nifti_to_obj(path=self.filename + "." + self.extension)
91
 
92
  LOGGER.info("Loading CT to numpy...")
93
  self.images = load_ct_to_numpy(path)
94
 
95
  LOGGER.info("Loading prediction volume to numpy..")
96
+ self.pred_images = load_pred_volume_to_numpy(self.filename + "." + self.extension)
97
 
98
  return "./prediction.obj"
99
+
100
+ def download_prediction(self):
101
+ if (not self.filename) or (not self.extension):
102
+ LOGGER.error("The prediction is not available or ready to download. Wait until the result is available in the 3D viewer.")
103
+ return self.filename + "." + self.extension
104
 
105
  def get_img_pred_pair(self, k):
106
  k = int(k)
 
108
  self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
109
  visible=True,
110
  elem_id="model-2d",
 
111
  color_map={self.class_name: "#ffae00"},
112
  height=512,
113
  width=512,
 
131
  autoscroll=True,
132
  elem_id="logs",
133
  show_copy_button=True,
 
134
  container=True,
 
135
  )
136
  demo.load(read_logs, None, logs, every=1)
137
 
 
167
  label="Task",
168
  info="Which structure to segment.",
169
  multiselect=False,
170
+ scale=1.0,
171
  )
172
  model_selector.input(
173
  fn=lambda x: self.set_class_name(x),
 
180
  "Run analysis",
181
  variant="primary",
182
  elem_id="run-button",
183
+ #scale=1.0,
184
+ # size=1.0,
185
  size="lg",
186
  )
187
+ #.style(
188
+ # full_width=False,
189
+ # size="lg",
190
+ #)
191
  run_btn.click(
192
  fn=lambda x: self.process(x),
193
  inputs=file_output,
194
  outputs=self.volume_renderer,
195
  )
196
 
197
+ download_btn = gr.DownloadButton(
198
+ "Download the result as NIfTI",
199
+ visible=True,
200
+ variant="secondary",
201
+ # scale=1.0,
202
+ size="sm",
203
+ )
204
+ download_btn.click(
205
+ fn=self.download_prediction,
206
+ inputs=None,
207
+ outputs=download_btn,
208
+ )
209
+
210
  with gr.Row():
211
  gr.Examples(
212
  examples=[
 
226
  )
227
 
228
  with gr.Row():
229
+ with gr.Group(): #gr.Box():
230
  with gr.Column():
231
  # create dummy image to be replaced by loaded images
232
  t = gr.AnnotatedImage(
233
  visible=True, elem_id="model-2d"
 
 
 
 
234
  )
235
+ #.style(
236
+ # color_map={self.class_name: "#ffae00"},
237
+ # height=512,
238
+ # width=512,
239
+ #)
240
 
241
  self.slider.input(
242
  self.get_img_pred_pair,
 
246
 
247
  self.slider.render()
248
 
249
+ with gr.Group(): #gr.Box():
250
  self.volume_renderer.render()
251
 
252
  # sharing app publicly -> share=True:
demo/src/inference.py CHANGED
@@ -11,6 +11,7 @@ def run_model(
11
  verbose: str = "info",
12
  task: str = "CT_Airways",
13
  name: str = "Airways",
 
14
  ):
15
  if verbose == "debug":
16
  logging.getLogger().setLevel(logging.DEBUG)
@@ -27,6 +28,9 @@ def run_model(
27
  if os.path.exists("./result/"):
28
  shutil.rmtree("./result/")
29
 
 
 
 
30
  patient_directory = ""
31
  output_path = ""
32
  try:
@@ -84,7 +88,7 @@ def run_model(
84
  + "-t1gd_annotation-"
85
  + name
86
  + ".nii.gz",
87
- "./prediction.nii.gz",
88
  )
89
  # Clean-up
90
  if os.path.exists(patient_directory):
 
11
  verbose: str = "info",
12
  task: str = "CT_Airways",
13
  name: str = "Airways",
14
+ output_filename: str = None,
15
  ):
16
  if verbose == "debug":
17
  logging.getLogger().setLevel(logging.DEBUG)
 
28
  if os.path.exists("./result/"):
29
  shutil.rmtree("./result/")
30
 
31
+ if output_filename is None:
32
+ raise ValueError("Please, set output_filename.")
33
+
34
  patient_directory = ""
35
  output_path = ""
36
  try:
 
88
  + "-t1gd_annotation-"
89
  + name
90
  + ".nii.gz",
91
+ output_filename,
92
  )
93
  # Clean-up
94
  if os.path.exists(patient_directory):