mertkarabacak commited on
Commit
5629ddc
1 Parent(s): 22ef63b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -152
app.py CHANGED
@@ -9,7 +9,7 @@ from math import sqrt
9
  from scipy import stats as st
10
  from matplotlib import pyplot as plt
11
 
12
- from sklearn.linear_model import LogisticRegression
13
 
14
  import shap
15
  import gradio as gr
@@ -21,24 +21,39 @@ from datasets import load_dataset
21
 
22
  #Read data training data.
23
 
24
- x1 = pd.read_csv("6m_data_train.csv", index_col = 0, low_memory = False)
 
 
25
 
26
- x2 = pd.read_csv("12m_data_train.csv", index_col = 0, low_memory = False)
 
 
27
 
28
- x3 = pd.read_csv("24m_data_train.csv", index_col = 0, low_memory = False)
29
-
30
- x4 = pd.read_csv("36m_data_train.csv", index_col = 0, low_memory = False)
31
 
 
 
 
32
 
33
  #Read validation data.
34
 
35
- x1_valid = pd.read_csv("6m_data_valid.csv", index_col = 0, low_memory = False)
 
 
36
 
37
- x2_valid = pd.read_csv("12m_data_valid.csv", index_col = 0, low_memory = False)
 
 
38
 
39
- x3_valid = pd.read_csv("24m_data_valid.csv", index_col = 0, low_memory = False)
 
 
40
 
41
- x4_valid = pd.read_csv("36m_data_valid.csv", index_col = 0, low_memory = False)
 
 
42
 
43
 
44
  #Define feature names.
@@ -59,82 +74,67 @@ f4_names = [f4.replace('__', ' - ') for f4 in f4_names]
59
  f4_names = [f4.replace('_', ' ') for f4 in f4_names]
60
 
61
 
62
- #Prepare training data for the outcome 1 (prolonged LOS).
63
  y1 = x1.pop('OUTCOME')
64
 
65
- #Prepare validation data for the outcome 1 (prolonged LOS).
66
- y1_valid = x1_valid.pop('OUTCOME')
67
-
68
- #Prepare training data for the outcome 2 (non-home discharges).
69
  y2 = x2.pop('OUTCOME')
70
 
71
- #Prepare validation data for the outcome 2 (non-home discharges).
72
- y2_valid = x2_valid.pop('OUTCOME')
73
-
74
- #Prepare training data for the outcome 3 (30-day readmissions).
75
  y3 = x3.pop('OUTCOME')
76
 
77
- #Prepare validation data for the outcome 3 (30-day readmissions).
78
- y3_valid = x3_valid.pop('OUTCOME')
79
-
80
- #Prepare training data for the outcome 4 (unplanned reoperations).
81
  y4 = x4.pop('OUTCOME')
82
 
83
- #Prepare validation data for the outcome 4 (unplanned reoperations).
84
- y4_valid = x4_valid.pop('OUTCOME')
85
-
86
 
87
- #Assign hyperparameters.
88
 
89
- y1_params = {'objective': 'binary', 'boosting_type': 'gbdt', 'lambda_l1': 2.874728678068222e-05, 'lambda_l2': 0.002100238688192627, 'num_leaves': 39, 'feature_fraction': 0.4504130718946593, 'bagging_fraction': 0.8916461477863318, 'bagging_freq': 7, 'min_child_samples': 45, 'metric': 'binary_logloss', 'verbosity': -1, 'random_state': 31}
90
- y2_params = {'objective': 'binary', 'boosting_type': 'gbdt', 'lambda_l1': 0.0002837317278662907, 'lambda_l2': 5.412618023120056e-06, 'num_leaves': 78, 'feature_fraction': 0.4044321534682025, 'bagging_fraction': 0.747678020066352, 'bagging_freq': 6, 'min_child_samples': 44, 'metric': 'binary_logloss', 'verbosity': -1, 'random_state': 31}
91
- y3_params = {'objective': 'binary', 'boosting_type': 'gbdt', 'lambda_l1': 0.00016354134178989566, 'lambda_l2': 0.005110516449291205, 'num_leaves': 4, 'feature_fraction': 0.525789668995701, 'bagging_fraction': 0.4203858842031528, 'bagging_freq': 3, 'min_child_samples': 66, 'metric': 'binary_logloss', 'verbosity': -1, 'random_state': 31}
92
- y4_params = {'objective': 'binary', 'boosting_type': 'gbdt', 'lambda_l1': 0.00014329772210712767, 'lambda_l2': 0.001638738946438707, 'num_leaves': 2, 'feature_fraction': 0.565882308738563, 'bagging_fraction': 0.47701769327658605, 'bagging_freq': 5, 'min_child_samples': 59, 'metric': 'binary_logloss', 'verbosity': -1, 'random_state': 31}
93
 
94
- #Training models.
 
95
 
96
- from lightgbm import LGBMClassifier
97
- lgb = LGBMClassifier(**y1_params)
98
- y1_model = lgb
99
 
100
- y1_model = y1_model.fit(x1, y1)
101
  y1_explainer = shap.Explainer(y1_model.predict, x1)
102
- y1_calib_probs = y1_model.predict_proba(x1_valid)
103
- y1_calib_model = LogisticRegression()
104
- y1_calib_model = y1_calib_model.fit(y1_calib_probs, y1_valid)
105
 
106
 
107
- from lightgbm import LGBMClassifier
108
- lgb = LGBMClassifier(**y2_params)
109
- y2_model = lgb
 
 
 
 
 
110
 
111
- y2_model = y2_model.fit(x2, y2)
112
  y2_explainer = shap.Explainer(y2_model.predict, x2)
113
- y2_calib_probs = y2_model.predict_proba(x2_valid)
114
- y2_calib_model = LogisticRegression()
115
- y2_calib_model = y2_calib_model.fit(y2_calib_probs, y2_valid)
116
 
117
 
118
- from lightgbm import LGBMClassifier
119
- lgb = LGBMClassifier(**y3_params)
120
- y3_model = lgb
 
 
 
 
 
121
 
122
- y3_model = y3_model.fit(x3, y3)
123
  y3_explainer = shap.Explainer(y3_model.predict, x3)
124
- y3_calib_probs = y3_model.predict_proba(x3_valid)
125
- y3_calib_model = LogisticRegression()
126
- y3_calib_model = y3_calib_model.fit(y3_calib_probs, y3_valid)
127
 
128
 
129
- from lightgbm import LGBMClassifier
130
- lgb = LGBMClassifier(**y4_params)
131
- y4_model = lgb
 
 
 
 
 
132
 
133
- y4_model = y4_model.fit(x4, y4)
134
  y4_explainer = shap.Explainer(y4_model.predict, x4)
135
- y4_calib_probs = y4_model.predict_proba(x4_valid)
136
- y4_calib_model = LogisticRegression()
137
- y4_calib_model = y4_calib_model.fit(y4_calib_probs, y4_valid)
138
 
139
 
140
  output_y1 = (
@@ -158,7 +158,7 @@ output_y2 = (
158
  output_y3 = (
159
  """
160
  <br/>
161
- <center>The probability of 24-month survival:</center>
162
  <br/>
163
  <center><h1>{:.2f}%</h1></center>
164
  """
@@ -167,7 +167,7 @@ output_y3 = (
167
  output_y4 = (
168
  """
169
  <br/>
170
- <center>The probability of 36-month survival:</center>
171
  <br/>
172
  <center><h1>{:.2f}%</h1></center>
173
  """
@@ -177,8 +177,7 @@ output_y4 = (
177
  #Define predict for y1.
178
  def y1_predict(*args):
179
  df1 = pd.DataFrame([args], columns=x1.columns)
180
- pos_pred = y1_model.predict_proba(df1)
181
- pos_pred = y1_calib_model.predict_proba(pos_pred)
182
  prob = pos_pred[0][1]
183
  prob = 1-prob
184
  output = output_y1.format(prob * 100)
@@ -187,8 +186,7 @@ def y1_predict(*args):
187
  #Define predict for y2.
188
  def y2_predict(*args):
189
  df2 = pd.DataFrame([args], columns=x2.columns)
190
- pos_pred = y2_model.predict_proba(df2)
191
- pos_pred = y2_calib_model.predict_proba(pos_pred)
192
  prob = pos_pred[0][1]
193
  prob = 1-prob
194
  output = output_y2.format(prob * 100)
@@ -197,8 +195,7 @@ def y2_predict(*args):
197
  #Define predict for y3.
198
  def y3_predict(*args):
199
  df3 = pd.DataFrame([args], columns=x3.columns)
200
- pos_pred = y3_model.predict_proba(df3)
201
- pos_pred = y3_calib_model.predict_proba(pos_pred)
202
  prob = pos_pred[0][1]
203
  prob = 1-prob
204
  output = output_y3.format(prob * 100)
@@ -207,10 +204,9 @@ def y3_predict(*args):
207
  #Define predict for y4.
208
  def y4_predict(*args):
209
  df4 = pd.DataFrame([args], columns=x4.columns)
210
- pos_pred = y4_model.predict_proba(df4)
211
- pos_pred = y4_calib_model.predict_proba(pos_pred)
212
  prob = pos_pred[0][1]
213
- prob = 1-prob
214
  output = output_y4.format(prob * 100)
215
  return output
216
 
@@ -297,14 +293,14 @@ def y4_interpret(*args):
297
  return fig
298
 
299
 
300
- with gr.Blocks(title = "NCDB-GBM") as demo:
301
 
302
  gr.Markdown(
303
  """
304
  <br/>
305
  <center><h2>NOT FOR CLINICAL USE</h2><center>
306
  <br/>
307
- <center><h1>GBM Survival Outcomes</h1></center>
308
  <center><h2>Prediction Tool</h2></center>
309
  <br/>
310
  <center><h3>This web application should not be used to guide any clinical decisions.</h3><center>
@@ -330,44 +326,44 @@ with gr.Blocks(title = "NCDB-GBM") as demo:
330
  </tr>
331
  <tr>
332
  <td>6-Month Mortality</td>
333
- <td>LightGBM</td>
334
- <td>0.694 (0.686 - 0.702)</td>
335
- <td>0.810 (0.803 - 0.817)</td>
336
- <td>0.772 (0.765 - 0.779)</td>
337
- <td>0.719 (0.711 - 0.727)</td>
338
- <td>0.831 (0.824 - 0.838)</td>
339
- <td>0.152 (0.146 - 0.158)</td>
340
  </tr>
341
  <tr>
342
  <td>12-Month Mortality</td>
343
- <td>LightGBM</td>
344
- <td>0.700 (0.692 - 0.708)</td>
345
- <td>0.742 (0.735 - 0.749)</td>
346
- <td>0.720 (0.712 - 0.728)</td>
347
- <td>0.821 (0.815 - 0.827)</td>
348
- <td>0.808 (0.792 - 0.807)</td>
349
- <td>0.183 (0.176 - 0.190)</td>
350
  </tr>
351
  <tr>
352
- <td>24-Month Mortality</td>
353
- <td>LightGBM</td>
354
- <td>0.742 (0.735 - 0.749)</td>
355
- <td>0.555 (0.547 - 0.563)</td>
356
- <td>0.702 (0.694 - 0.710)</td>
357
- <td>0.897 (0.892 - 0.902)</td>
358
- <td>0.716 (0.706 - 0.727)</td>
359
- <td>0.153 (0.147 - 0.159)</td>
360
- </tr>
361
  <tr>
362
- <td>36-Month Mortality</td>
363
- <td>LightGBM</td>
364
- <td>0.705 (0.697 - 0.713)</td>
365
- <td>0.576 (0.568 - 0.584)</td>
366
- <td>0.689 (0.681 - 0.697)</td>
367
- <td>0.937 (0.933 - 0.941)</td>
368
- <td>0.707 (0.687 - 0.713)</td>
369
- <td>0.103 (0.098 - 0.108)</td>
370
- </tr>
371
  </table>
372
  </div>
373
  """
@@ -381,45 +377,29 @@ with gr.Blocks(title = "NCDB-GBM") as demo:
381
 
382
  Sex = gr.Dropdown(label = "Sex", choices = ['Male', 'Female'], type = 'index', value = 'Male')
383
 
384
- Race = gr.Dropdown(label = "Race", choices = ['White', 'Black', 'Asian Indian or Pakistani', 'Chinese', 'Filipino', 'American Indian, Aleutian, or Eskimo', 'Vietnamese', 'Korean', 'Other or Unknown'], type = 'index', value = 'White')
385
 
386
  Hispanic_Ethnicity = gr.Dropdown(label = "Hispanic Ethnicity", choices = ['No', 'Yes', 'Unknown'], type = 'index', value = 'No')
387
 
388
- Primary_Payor = gr.Dropdown(label = "Primary Payor", choices = ['Private insurance', 'Medicare', 'Medicaid', 'Other government', 'Not insured', 'Unknown'], type = 'index', value = 'Private insurance')
389
 
390
- Facility_Type = gr.Dropdown(label = "Facility Type", choices = ['Academic/Research Program', 'Comprehensive Community Cancer Program', 'Integrated Network Cancer Program', 'Community Cancer Program', 'Other or Unknown'], type = 'index', value = 'Academic/Research Program')
391
 
392
- Facility_Location = gr.Dropdown(label = "Facility Location", choices = ['South Atlantic', 'East North Central', 'Middle Atlantic', 'East North Central', 'Middle Atlantic', 'Pacific', 'West South Central', 'West North Central', 'East South Central', 'New England', 'Mountain', 'Unknown or Other'], type = 'index', value = 'South Atlantic')
393
 
394
- CharlsonDeyo_Score = gr.Dropdown(label = "Charlson-Deyo Score", choices = ['0', '1', '2', 'Greater than 3'], type = 'index', value = '0')
395
-
396
- Karnofsky_Performance_Scale = gr.Dropdown(label = "Karnofsky Performance Scale", choices = ['KPS 0-20', 'KPS 21-40', 'KPS 41-60', 'KPS 61-80', 'KPS 81-100', 'Unknown'], type = 'index', value = 'KPS 81-100')
397
-
398
- Laterality = gr.Dropdown(label = "Laterality", choices = ['Right', 'Left', 'Bilateral', 'Midline', 'Unknown'], type = 'index', value = 'Right')
399
-
400
- Tumor_Localization = gr.Dropdown(label = "Tumor Localization", choices = ['Frontal lobe', 'Temporal lobe', 'Parietal lobe', 'Occipital lobe', 'Overlapping', 'Intraventricular', 'Cerebellum', 'Brain stem', 'Unknown'], type = 'index', value = 'Frontal lobe')
401
-
402
- Focality = gr.Dropdown(label = "Focality", choices = ['Unifocal', 'Multifocal', 'Unknown'], type = 'index', value = 'Unifocal')
403
-
404
- Diagnostic_Biopsy = gr.Dropdown(label = "Diagnostic Biopsy", choices = ['No', 'Yes', 'Unknown'], type = 'index', value = 'No')
405
-
406
- Tumor_Size = gr.Dropdown(label = "Tumor Size", choices = ['< 2 cm', '2 - 3.9 cm', '4 - 5.9 cm', '6 - 7.9 cm', '8 - 9.9 cm', '10 - 11.9 cm', '12 - 13.9 cm', '14 - 15.9 cm', '16 - 17.9 cm', '18 - 19.9 cm', '> 20 cm', 'Unknown'], type = 'index', value = '< 2 cm')
407
 
408
- CoDeletion_1p19q = gr.Dropdown(label = "1p19q Co-Deletion", choices = ['No', 'Yes', 'Unknown'], type = 'index', value = 'No')
409
 
410
- MGMT_Methylation = gr.Dropdown(label = "MGMT Methylation", choices = ['Unmethylated', 'Methylated', 'Unknown'], type = 'index', value = 'Unmethylated')
411
-
412
- Ki67_Labeling_Index = gr.Dropdown(label = 'Ki-67 Labeling Index', choices = ['0-20%', '21-40%', '41-60%', '61-80%', '81-100%', 'Normal (no percentage available)', 'Slightly elevated (no percentage available)', 'Elevated (no percentage available)', 'Unknown'], type = 'index', value = '0-20%')
413
 
414
- Resective_Surgery = gr.Dropdown(label = "Resective Surgery", choices = ['No', 'Yes', 'Unknown'], type = 'index', value = 'Yes')
415
 
416
- Extent_of_Resection = gr.Dropdown(label = "Extent of Resection", choices = ['No resective surgery was performed', 'Gross total resection', 'Subtotal resection', 'Unknown'], type = 'index', value = 'Gross total resection')
417
 
418
- Radiation_Treatment = gr.Dropdown(label = "Radiation Treatment", choices = ['No', 'Yes', 'Unknown'], type = 'index', value = 'Yes')
419
 
420
- Chemotherapy = gr.Dropdown(label = "Chemotherapy", choices = ['No', 'Yes (single-agent chemotherapy)', 'Yes (multi-agent chemotherapy)', 'Yes (details unknown)', 'Unknown'], type = 'index', value = 'No')
421
-
422
- Immunotherapy = gr.Dropdown(label = "Immunotherapy", choices = ['No', 'Yes', 'Unknown'], type = 'index', value = 'No')
423
 
424
  with gr.Column():
425
 
@@ -429,7 +409,7 @@ with gr.Blocks(title = "NCDB-GBM") as demo:
429
  """
430
  <center> <h2>6-Month Survival</h2> </center>
431
  <br/>
432
- <center> This model uses the LightGBM algorithm.</center>
433
  <br/>
434
  """
435
  )
@@ -473,7 +453,7 @@ with gr.Blocks(title = "NCDB-GBM") as demo:
473
  """
474
  <center> <h2>12-Month Survival</h2> </center>
475
  <br/>
476
- <center> This model uses the LightGBM algorithm.</center>
477
  <br/>
478
  """
479
  )
@@ -516,9 +496,9 @@ with gr.Blocks(title = "NCDB-GBM") as demo:
516
 
517
  gr.Markdown(
518
  """
519
- <center> <h2>24-Month Survival</h2> </center>
520
  <br/>
521
- <center> This model uses the LightGBM algorithm.</center>
522
  <br/>
523
  """
524
  )
@@ -561,9 +541,9 @@ with gr.Blocks(title = "NCDB-GBM") as demo:
561
 
562
  gr.Markdown(
563
  """
564
- <center> <h2>36-Month Survival</h2> </center>
565
  <br/>
566
- <center> This model uses the LightGBM algorithm.</center>
567
  <br/>
568
  """
569
  )
@@ -600,56 +580,55 @@ with gr.Blocks(title = "NCDB-GBM") as demo:
600
  """
601
  <br/>
602
  """
603
- )
604
-
605
 
606
  y1_predict_btn.click(
607
  y1_predict,
608
- inputs = [Facility_Type,Facility_Location,Age,Sex,Race,Hispanic_Ethnicity,Primary_Payor,CharlsonDeyo_Score,Tumor_Localization,Laterality,Diagnostic_Biopsy,Ki67_Labeling_Index,Karnofsky_Performance_Scale,MGMT_Methylation,Focality,Tumor_Size,Chemotherapy,Immunotherapy,CoDeletion_1p19q,Resective_Surgery,Extent_of_Resection,Radiation_Treatment],
609
  outputs = [label1]
610
  )
611
 
612
  y2_predict_btn.click(
613
  y2_predict,
614
- inputs = [Facility_Type,Facility_Location,Age,Sex,Race,Hispanic_Ethnicity,Primary_Payor,CharlsonDeyo_Score,Tumor_Localization,Laterality,Diagnostic_Biopsy,Ki67_Labeling_Index,Karnofsky_Performance_Scale,MGMT_Methylation,Focality,Tumor_Size,Chemotherapy,Immunotherapy,CoDeletion_1p19q,Resective_Surgery,Extent_of_Resection,Radiation_Treatment],
615
  outputs = [label2]
616
  )
617
 
618
  y3_predict_btn.click(
619
  y3_predict,
620
- inputs = [Facility_Type,Facility_Location,Age,Sex,Race,Hispanic_Ethnicity,Primary_Payor,CharlsonDeyo_Score,Tumor_Localization,Laterality,Diagnostic_Biopsy,Ki67_Labeling_Index,Karnofsky_Performance_Scale,MGMT_Methylation,Focality,Tumor_Size,Chemotherapy,Immunotherapy,CoDeletion_1p19q,Resective_Surgery,Extent_of_Resection,Radiation_Treatment],
621
  outputs = [label3]
622
  )
623
-
624
  y4_predict_btn.click(
625
  y4_predict,
626
- inputs = [Facility_Type,Facility_Location,Age,Sex,Race,Hispanic_Ethnicity,Primary_Payor,CharlsonDeyo_Score,Tumor_Localization,Laterality,Diagnostic_Biopsy,Ki67_Labeling_Index,Karnofsky_Performance_Scale,MGMT_Methylation,Focality,Tumor_Size,Chemotherapy,Immunotherapy,CoDeletion_1p19q,Resective_Surgery,Extent_of_Resection,Radiation_Treatment],
627
  outputs = [label4]
628
- )
629
 
630
  y1_interpret_btn.click(
631
  y1_interpret,
632
- inputs = [Facility_Type,Facility_Location,Age,Sex,Race,Hispanic_Ethnicity,Primary_Payor,CharlsonDeyo_Score,Tumor_Localization,Laterality,Diagnostic_Biopsy,Ki67_Labeling_Index,Karnofsky_Performance_Scale,MGMT_Methylation,Focality,Tumor_Size,Chemotherapy,Immunotherapy,CoDeletion_1p19q,Resective_Surgery,Extent_of_Resection,Radiation_Treatment],
633
  outputs = [plot1],
634
  )
635
 
636
  y2_interpret_btn.click(
637
  y2_interpret,
638
- inputs = [Facility_Type,Facility_Location,Age,Sex,Race,Hispanic_Ethnicity,Primary_Payor,CharlsonDeyo_Score,Tumor_Localization,Laterality,Diagnostic_Biopsy,Ki67_Labeling_Index,Karnofsky_Performance_Scale,MGMT_Methylation,Focality,Tumor_Size,Chemotherapy,Immunotherapy,CoDeletion_1p19q,Resective_Surgery,Extent_of_Resection,Radiation_Treatment],
639
  outputs = [plot2],
640
  )
641
 
642
  y3_interpret_btn.click(
643
  y3_interpret,
644
- inputs = [Facility_Type,Facility_Location,Age,Sex,Race,Hispanic_Ethnicity,Primary_Payor,CharlsonDeyo_Score,Tumor_Localization,Laterality,Diagnostic_Biopsy,Ki67_Labeling_Index,Karnofsky_Performance_Scale,MGMT_Methylation,Focality,Tumor_Size,Chemotherapy,Immunotherapy,CoDeletion_1p19q,Resective_Surgery,Extent_of_Resection,Radiation_Treatment],
645
  outputs = [plot3],
646
  )
647
-
648
  y4_interpret_btn.click(
649
  y4_interpret,
650
- inputs = [Facility_Type,Facility_Location,Age,Sex,Race,Hispanic_Ethnicity,Primary_Payor,CharlsonDeyo_Score,Tumor_Localization,Laterality,Diagnostic_Biopsy,Ki67_Labeling_Index,Karnofsky_Performance_Scale,MGMT_Methylation,Focality,Tumor_Size,Chemotherapy,Immunotherapy,CoDeletion_1p19q,Resective_Surgery,Extent_of_Resection,Radiation_Treatment],
651
  outputs = [plot4],
652
- )
653
 
654
  gr.Markdown(
655
  """
 
9
  from scipy import stats as st
10
  from matplotlib import pyplot as plt
11
 
12
+ from sklearn.calibration import CalibratedClassifierCV
13
 
14
  import shap
15
  import gradio as gr
 
21
 
22
  #Read data training data.
23
 
24
+ x1 = load_dataset("mertkarabacak/NCDB-GBM", data_files="6m_data_train.csv", use_auth_token = HF_TOKEN)
25
+ x1 = pd.DataFrame(x1['train'])
26
+ x1 = x1.iloc[:, 1:]
27
 
28
+ x2 = load_dataset("mertkarabacak/NCDB-GBM", data_files="12m_data_train.csv", use_auth_token = HF_TOKEN)
29
+ x2 = pd.DataFrame(x2['train'])
30
+ x2 = x2.iloc[:, 1:]
31
 
32
+ x3 = load_dataset("mertkarabacak/NCDB-GBM", data_files="18m_data_train.csv", use_auth_token = HF_TOKEN)
33
+ x3 = pd.DataFrame(x3['train'])
34
+ x3 = x3.iloc[:, 1:]
35
 
36
+ x4 = load_dataset("mertkarabacak/NCDB-GBM", data_files="24m_data_train.csv", use_auth_token = HF_TOKEN)
37
+ x4 = pd.DataFrame(x4['train'])
38
+ x4 = x4.iloc[:, 1:]
39
 
40
  #Read validation data.
41
 
42
+ x1_valid = load_dataset("mertkarabacak/NCDB-GBM", data_files="6m_data_valid.csv", use_auth_token = HF_TOKEN)
43
+ x1_valid = pd.DataFrame(x1_valid['train'])
44
+ x1_valid = x1_valid.iloc[:, 1:]
45
 
46
+ x2_valid = load_dataset("mertkarabacak/NCDB-GBM", data_files="12m_data_valid.csv", use_auth_token = HF_TOKEN)
47
+ x2_valid = pd.DataFrame(x2_valid['train'])
48
+ x2_valid = x2_valid.iloc[:, 1:]
49
 
50
+ x3_valid = load_dataset("mertkarabacak/NCDB-GBM", data_files="18m_data_valid.csv", use_auth_token = HF_TOKEN)
51
+ x3_valid = pd.DataFrame(x3_valid['train'])
52
+ x3_valid = x3_valid.iloc[:, 1:]
53
 
54
+ x4_valid = load_dataset("mertkarabacak/NCDB-GBM", data_files="24m_data_valid.csv", use_auth_token = HF_TOKEN)
55
+ x4_valid = pd.DataFrame(x4_valid['train'])
56
+ x4_valid = x4_valid.iloc[:, 1:]
57
 
58
 
59
  #Define feature names.
 
74
  f4_names = [f4.replace('_', ' ') for f4 in f4_names]
75
 
76
 
77
+ #Prepare training data for the outcome 1.
78
  y1 = x1.pop('OUTCOME')
79
 
80
+ #Prepare training data for the outcome 2.
 
 
 
81
  y2 = x2.pop('OUTCOME')
82
 
83
+ #Prepare training data for the outcome 3.
 
 
 
84
  y3 = x3.pop('OUTCOME')
85
 
86
+ #Prepare training data for the outcome 3.
 
 
 
87
  y4 = x4.pop('OUTCOME')
88
 
 
 
 
89
 
90
+ #Training models.
91
 
92
+ from tabpfn import TabPFNClassifier
93
+ tabpfn = TabPFNClassifier(device='cuda', N_ensemble_configurations=1)
 
 
94
 
95
+ y1_model = tabpfn
96
+ y1_model = y1_model.fit(x1, y1, overwrite_warning=True)
97
 
98
+ y1_calib_model = CalibratedClassifierCV(y1_model, method='sigmoid', cv='prefit')
99
+ y1_calib_model = y1_calib_model.fit(x1, y1)
 
100
 
 
101
  y1_explainer = shap.Explainer(y1_model.predict, x1)
 
 
 
102
 
103
 
104
+ from tabpfn import TabPFNClassifier
105
+ tabpfn = TabPFNClassifier(device='cuda', N_ensemble_configurations=1)
106
+
107
+ y2_model = tabpfn
108
+ y2_model = y2_model.fit(x2, y2, overwrite_warning=True)
109
+
110
+ y2_calib_model = CalibratedClassifierCV(y2_model, method='sigmoid', cv='prefit')
111
+ y2_calib_model = y2_calib_model.fit(x2, y2)
112
 
 
113
  y2_explainer = shap.Explainer(y2_model.predict, x2)
 
 
 
114
 
115
 
116
+ from tabpfn import TabPFNClassifier
117
+ tabpfn = TabPFNClassifier(device='cuda', N_ensemble_configurations=1)
118
+
119
+ y3_model = tabpfn
120
+ y3_model = y3_model.fit(x3, y3, overwrite_warning=True)
121
+
122
+ y3_calib_model = CalibratedClassifierCV(y3_model, method='sigmoid', cv='prefit')
123
+ y3_calib_model = y3_calib_model.fit(x3, y3)
124
 
 
125
  y3_explainer = shap.Explainer(y3_model.predict, x3)
 
 
 
126
 
127
 
128
+ from tabpfn import TabPFNClassifier
129
+ tabpfn = TabPFNClassifier(device='cuda', N_ensemble_configurations=1)
130
+
131
+ y4_model = tabpfn
132
+ y4_model = y4_model.fit(x4, y4, overwrite_warning=True)
133
+
134
+ y4_calib_model = CalibratedClassifierCV(y4_model, method='sigmoid', cv='prefit')
135
+ y4_calib_model = y4_calib_model.fit(x4, y4)
136
 
 
137
  y4_explainer = shap.Explainer(y4_model.predict, x4)
 
 
 
138
 
139
 
140
  output_y1 = (
 
158
  output_y3 = (
159
  """
160
  <br/>
161
+ <center>The probability of 18-month survival:</center>
162
  <br/>
163
  <center><h1>{:.2f}%</h1></center>
164
  """
 
167
  output_y4 = (
168
  """
169
  <br/>
170
+ <center>The probability of 24-month survival:</center>
171
  <br/>
172
  <center><h1>{:.2f}%</h1></center>
173
  """
 
177
  #Define predict for y1.
178
  def y1_predict(*args):
179
  df1 = pd.DataFrame([args], columns=x1.columns)
180
+ pos_pred = y1_calib_model.predict_proba(df1)
 
181
  prob = pos_pred[0][1]
182
  prob = 1-prob
183
  output = output_y1.format(prob * 100)
 
186
  #Define predict for y2.
187
  def y2_predict(*args):
188
  df2 = pd.DataFrame([args], columns=x2.columns)
189
+ pos_pred = y2_calib_model.predict_proba(df2)
 
190
  prob = pos_pred[0][1]
191
  prob = 1-prob
192
  output = output_y2.format(prob * 100)
 
195
  #Define predict for y3.
196
  def y3_predict(*args):
197
  df3 = pd.DataFrame([args], columns=x3.columns)
198
+ pos_pred = y3_calib_model.predict_proba(df3)
 
199
  prob = pos_pred[0][1]
200
  prob = 1-prob
201
  output = output_y3.format(prob * 100)
 
204
  #Define predict for y4.
205
  def y4_predict(*args):
206
  df4 = pd.DataFrame([args], columns=x4.columns)
207
+ pos_pred = y4_calib_model.predict_proba(df4)
 
208
  prob = pos_pred[0][1]
209
+ prob = 1-prob
210
  output = output_y4.format(prob * 100)
211
  return output
212
 
 
293
  return fig
294
 
295
 
296
+ with gr.Blocks(title = "NCDB-Meningioma") as demo:
297
 
298
  gr.Markdown(
299
  """
300
  <br/>
301
  <center><h2>NOT FOR CLINICAL USE</h2><center>
302
  <br/>
303
+ <center><h1>IDH-wt Glioblastoma Survival Outcomes</h1></center>
304
  <center><h2>Prediction Tool</h2></center>
305
  <br/>
306
  <center><h3>This web application should not be used to guide any clinical decisions.</h3><center>
 
326
  </tr>
327
  <tr>
328
  <td>6-Month Mortality</td>
329
+ <td>TabPFN</td>
330
+ <td>0.755 (0.733 - 0.777)</td>
331
+ <td>0.767 (0.745 - 0.789)</td>
332
+ <td>0.764 (0.742 - 0.786)</td>
333
+ <td>0.654 (0.630 - 0.678)</td>
334
+ <td>0.840 (0.811 - 0.857)</td>
335
+ <td>0.135 (0.117 - 0.153)</td>
336
  </tr>
337
  <tr>
338
  <td>12-Month Mortality</td>
339
+ <td>TabPFN</td>
340
+ <td>0.685 (0.661 - 0.709)</td>
341
+ <td>0.728 (0.705 - 0.751)</td>
342
+ <td>0.707 (0.683 - 0.731)</td>
343
+ <td>0.746 (0.723 - 0.769)</td>
344
+ <td>0.783 (0.752 - 0.800)</td>
345
+ <td>0.203 (0.182 - 0.224)</td>
346
  </tr>
347
  <tr>
348
+ <td>18-Month Mortality</td>
349
+ <td>TabPFN</td>
350
+ <td>0.706 (0.682 - 0.730)</td>
351
+ <td>0.659 (0.634 - 0.684)</td>
352
+ <td>0.689 (0.665 - 0.713)</td>
353
+ <td>0.832 (0.812 - 0.852)</td>
354
+ <td>0.749 (0.717 - 0.768)</td>
355
+ <td>0.193 (0.172 - 0.214)</td>
356
+ </tr>
357
  <tr>
358
+ <td>24-Month Mortality</td>
359
+ <td>TabPFN</td>
360
+ <td>0.732 (0.708 - 0.756)</td>
361
+ <td>0.716 (0.691 - 0.741)</td>
362
+ <td>0.728 (0.704 - 0.752)</td>
363
+ <td>0.925 (0.911 - 0.939)</td>
364
+ <td>0.780 (0.755 - 0.813)</td>
365
+ <td>0.141 (0.122 - 0.160)</td>
366
+ </tr>
367
  </table>
368
  </div>
369
  """
 
377
 
378
  Sex = gr.Dropdown(label = "Sex", choices = ['Male', 'Female'], type = 'index', value = 'Male')
379
 
380
+ Race = gr.Dropdown(label = "Race", choices = ['White', 'Black', 'Other'], type = 'index', value = 'White')
381
 
382
  Hispanic_Ethnicity = gr.Dropdown(label = "Hispanic Ethnicity", choices = ['No', 'Yes', 'Unknown'], type = 'index', value = 'No')
383
 
384
+ Insurance_Status = gr.Dropdown(label = "Insurance Status", choices = ['Private insurance', 'Medicare', 'Medicaid', 'Other government', 'Not insured', 'Unknown'], type = 'index', value = 'Private insurance')
385
 
386
+ Facility_Type = gr.Dropdown(label = "Facility Type", choices = ['Academic/Research Program', 'Community Cancer Program', 'Integrated Network Cancer Program'], type = 'index', value = 'Academic/Research Program')
387
 
388
+ Facility_Location = gr.Dropdown(label = "Facility Location", choices = ['Central', 'Atlantic', 'Pacific', 'Mountain', 'New England'], type = 'index', value = 'Central')
389
 
390
+ CharlsonDeyo_Score = gr.Dropdown(label = "Charlson-Deyo Score", choices = ['0', '1', '>2'], type = 'index', value = '0')
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
+ MGMT_Methylation = gr.Dropdown(label = "MGMT Methylation", choices = ['Unmethylated', 'Methylated'], type = 'index', value = 'Unmethylated')
393
 
394
+ Tumor_Size = gr.Dropdown(label = "Tumor Size (mm)", minimum = 1, maximum = 300, step = 1, value = 30)
 
 
395
 
396
+ Extent_of_Resection = gr.Dropdown(label = 'Extent of Resection', choices = ['No resective surgery was performed', 'Gross total resection'], type = 'index', value = 'Gross total resection')
397
 
398
+ Radiotherapy = gr.Dropdown(label = 'Radiotherapy', choices = ['No', 'Yes'], type = 'index', value = 'Yes')
399
 
400
+ Chemotherapy = gr.Dropdown(label = "Chemotherapy", choices = ['No', 'Yes'], type = 'index', value = 'Yes')
401
 
402
+ Immunotherapy = gr.Dropdown(label = "Immunotherapy", choices = ['No', 'Yes'], type = 'index', value = 'No')
 
 
403
 
404
  with gr.Column():
405
 
 
409
  """
410
  <center> <h2>6-Month Survival</h2> </center>
411
  <br/>
412
+ <center> This model uses the Random Forest algorithm.</center>
413
  <br/>
414
  """
415
  )
 
453
  """
454
  <center> <h2>12-Month Survival</h2> </center>
455
  <br/>
456
+ <center> This model uses the Random Forest algorithm.</center>
457
  <br/>
458
  """
459
  )
 
496
 
497
  gr.Markdown(
498
  """
499
+ <center> <h2> 18-Month Survival</h2> </center>
500
  <br/>
501
+ <center> This model uses the TabPFN algorithm.</center>
502
  <br/>
503
  """
504
  )
 
541
 
542
  gr.Markdown(
543
  """
544
+ <center> <h2> 24-Month Survival</h2> </center>
545
  <br/>
546
+ <center> This model uses the TabPFN algorithm.</center>
547
  <br/>
548
  """
549
  )
 
580
  """
581
  <br/>
582
  """
583
+ )
 
584
 
585
  y1_predict_btn.click(
586
  y1_predict,
587
+ inputs = [Age, Sex, Race, Hispanic_Ethnicity, Insurance_Status, Facility_Type, Facility_Location, CharlsonDeyo_Score, Tumor_Size, MGMT_Methylation, Extent_of_Resection, Radiotherapy, Chemotherapy, Immunotherapy],
588
  outputs = [label1]
589
  )
590
 
591
  y2_predict_btn.click(
592
  y2_predict,
593
+ inputs = [Age, Sex, Race, Hispanic_Ethnicity, Insurance_Status, Facility_Type, Facility_Location, CharlsonDeyo_Score, Tumor_Size, MGMT_Methylation, Extent_of_Resection, Radiotherapy, Chemotherapy, Immunotherapy],
594
  outputs = [label2]
595
  )
596
 
597
  y3_predict_btn.click(
598
  y3_predict,
599
+ inputs = [Age, Sex, Race, Hispanic_Ethnicity, Insurance_Status, Facility_Type, Facility_Location, CharlsonDeyo_Score, Tumor_Size, MGMT_Methylation, Extent_of_Resection, Radiotherapy, Chemotherapy, Immunotherapy],
600
  outputs = [label3]
601
  )
602
+
603
  y4_predict_btn.click(
604
  y4_predict,
605
+ inputs = [Age, Sex, Race, Hispanic_Ethnicity, Insurance_Status, Facility_Type, Facility_Location, CharlsonDeyo_Score, Tumor_Size, MGMT_Methylation, Extent_of_Resection, Radiotherapy, Chemotherapy, Immunotherapy],
606
  outputs = [label4]
607
+ )
608
 
609
  y1_interpret_btn.click(
610
  y1_interpret,
611
+ inputs = [Age, Sex, Race, Hispanic_Ethnicity, Insurance_Status, Facility_Type, Facility_Location, CharlsonDeyo_Score, Tumor_Size, MGMT_Methylation, Extent_of_Resection, Radiotherapy, Chemotherapy, Immunotherapy],
612
  outputs = [plot1],
613
  )
614
 
615
  y2_interpret_btn.click(
616
  y2_interpret,
617
+ inputs = [Age, Sex, Race, Hispanic_Ethnicity, Insurance_Status, Facility_Type, Facility_Location, CharlsonDeyo_Score, Tumor_Size, MGMT_Methylation, Extent_of_Resection, Radiotherapy, Chemotherapy, Immunotherapy],
618
  outputs = [plot2],
619
  )
620
 
621
  y3_interpret_btn.click(
622
  y3_interpret,
623
+ inputs = [Age, Sex, Race, Hispanic_Ethnicity, Insurance_Status, Facility_Type, Facility_Location, CharlsonDeyo_Score, Tumor_Size, MGMT_Methylation, Extent_of_Resection, Radiotherapy, Chemotherapy, Immunotherapy],
624
  outputs = [plot3],
625
  )
626
+
627
  y4_interpret_btn.click(
628
  y4_interpret,
629
+ inputs = [Age, Sex, Race, Hispanic_Ethnicity, Insurance_Status, Facility_Type, Facility_Location, CharlsonDeyo_Score, Tumor_Size, MGMT_Methylation, Extent_of_Resection, Radiotherapy, Chemotherapy, Immunotherapy],
630
  outputs = [plot4],
631
+ )
632
 
633
  gr.Markdown(
634
  """