mertkarabacak commited on
Commit
b799b72
1 Parent(s): 16543f1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -1
app.py CHANGED
@@ -24,6 +24,8 @@ import shap
24
  import gradio as gr
25
  import random
26
  import re
 
 
27
 
28
 
29
  #Read data.
@@ -34,6 +36,12 @@ variables = ['Age', 'Sex', 'Ethnicity', 'Weight', 'Height', 'Systolic_Blood_Pres
34
  x1 = x1[variables]
35
 
36
 
 
 
 
 
 
 
37
  #Assign unique values as answer options.
38
  unique_SEX = ['Male', 'Female', 'Unknown']
39
  unique_RACE = ['White', 'Black', 'Asian', 'American Indian', 'Pacific Islander', 'Other', 'Unknown']
@@ -171,12 +179,21 @@ def y1_predict_rf(*args):
171
  return {"Mortality": float(pos_pred[0][1]), "No Mortality": float(pos_pred[0][0])}
172
 
173
 
 
 
 
 
 
 
 
 
 
174
  #Define interpret for y1 (mortality).
175
  def y1_interpret_xgb(*args):
176
  df = pd.DataFrame([args], columns=x1.columns)
177
  df = df.astype({col: "category" for col in categorical_columns1})
178
  shap_values = y1_explainer_xgb.shap_values(xgb.DMatrix(df, enable_categorical=True))
179
- fig_y1 = shap.bar_plot(shap_values[0], max_display = 10, show = False, feature_names = x1.columns)
180
  fig_y1 = plt.gcf()
181
  ax_y1 = plt.gca()
182
  fig_y1.set_figheight(6)
 
24
  import gradio as gr
25
  import random
26
  import re
27
+ import textwrap
28
+
29
 
30
 
31
  #Read data.
 
36
  x1 = x1[variables]
37
 
38
 
39
+ #Define feature names.
40
+ f_names = x1.columns
41
+ f_names = [f_names.replace('__', ' - ') for f in f_names]
42
+ f_names = [f_names.replace('_', ' ') for f in f_names]
43
+
44
+
45
  #Assign unique values as answer options.
46
  unique_SEX = ['Male', 'Female', 'Unknown']
47
  unique_RACE = ['White', 'Black', 'Asian', 'American Indian', 'Pacific Islander', 'Other', 'Unknown']
 
179
  return {"Mortality": float(pos_pred[0][1]), "No Mortality": float(pos_pred[0][0])}
180
 
181
 
182
+ def wrap_labels(ax, width, break_long_words=False):
183
+ labels = []
184
+ for label in ax.get_yticklabels():
185
+ text = label.get_text()
186
+ labels.append(textwrap.fill(text, width=width,
187
+ break_long_words=break_long_words))
188
+ ax.set_yticklabels(labels, rotation=0)
189
+
190
+
191
  #Define interpret for y1 (mortality).
192
  def y1_interpret_xgb(*args):
193
  df = pd.DataFrame([args], columns=x1.columns)
194
  df = df.astype({col: "category" for col in categorical_columns1})
195
  shap_values = y1_explainer_xgb.shap_values(xgb.DMatrix(df, enable_categorical=True))
196
+ fig_y1 = shap.bar_plot(shap_values[0], max_display = 10, show = False, feature_names = f_names)
197
  fig_y1 = plt.gcf()
198
  ax_y1 = plt.gca()
199
  fig_y1.set_figheight(6)