mertkarabacak
commited on
Commit
•
b799b72
1
Parent(s):
16543f1
Upload app.py
Browse files
app.py
CHANGED
@@ -24,6 +24,8 @@ import shap
|
|
24 |
import gradio as gr
|
25 |
import random
|
26 |
import re
|
|
|
|
|
27 |
|
28 |
|
29 |
#Read data.
|
@@ -34,6 +36,12 @@ variables = ['Age', 'Sex', 'Ethnicity', 'Weight', 'Height', 'Systolic_Blood_Pres
|
|
34 |
x1 = x1[variables]
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
#Assign unique values as answer options.
|
38 |
unique_SEX = ['Male', 'Female', 'Unknown']
|
39 |
unique_RACE = ['White', 'Black', 'Asian', 'American Indian', 'Pacific Islander', 'Other', 'Unknown']
|
@@ -171,12 +179,21 @@ def y1_predict_rf(*args):
|
|
171 |
return {"Mortality": float(pos_pred[0][1]), "No Mortality": float(pos_pred[0][0])}
|
172 |
|
173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
#Define interpret for y1 (mortality).
|
175 |
def y1_interpret_xgb(*args):
|
176 |
df = pd.DataFrame([args], columns=x1.columns)
|
177 |
df = df.astype({col: "category" for col in categorical_columns1})
|
178 |
shap_values = y1_explainer_xgb.shap_values(xgb.DMatrix(df, enable_categorical=True))
|
179 |
-
fig_y1 = shap.bar_plot(shap_values[0], max_display = 10, show = False, feature_names =
|
180 |
fig_y1 = plt.gcf()
|
181 |
ax_y1 = plt.gca()
|
182 |
fig_y1.set_figheight(6)
|
|
|
24 |
import gradio as gr
|
25 |
import random
|
26 |
import re
|
27 |
+
import textwrap
|
28 |
+
|
29 |
|
30 |
|
31 |
#Read data.
|
|
|
36 |
x1 = x1[variables]
|
37 |
|
38 |
|
39 |
+
#Define feature names.
|
40 |
+
f_names = x1.columns
|
41 |
+
f_names = [f_names.replace('__', ' - ') for f in f_names]
|
42 |
+
f_names = [f_names.replace('_', ' ') for f in f_names]
|
43 |
+
|
44 |
+
|
45 |
#Assign unique values as answer options.
|
46 |
unique_SEX = ['Male', 'Female', 'Unknown']
|
47 |
unique_RACE = ['White', 'Black', 'Asian', 'American Indian', 'Pacific Islander', 'Other', 'Unknown']
|
|
|
179 |
return {"Mortality": float(pos_pred[0][1]), "No Mortality": float(pos_pred[0][0])}
|
180 |
|
181 |
|
182 |
+
def wrap_labels(ax, width, break_long_words=False):
|
183 |
+
labels = []
|
184 |
+
for label in ax.get_yticklabels():
|
185 |
+
text = label.get_text()
|
186 |
+
labels.append(textwrap.fill(text, width=width,
|
187 |
+
break_long_words=break_long_words))
|
188 |
+
ax.set_yticklabels(labels, rotation=0)
|
189 |
+
|
190 |
+
|
191 |
#Define interpret for y1 (mortality).
|
192 |
def y1_interpret_xgb(*args):
|
193 |
df = pd.DataFrame([args], columns=x1.columns)
|
194 |
df = df.astype({col: "category" for col in categorical_columns1})
|
195 |
shap_values = y1_explainer_xgb.shap_values(xgb.DMatrix(df, enable_categorical=True))
|
196 |
+
fig_y1 = shap.bar_plot(shap_values[0], max_display = 10, show = False, feature_names = f_names)
|
197 |
fig_y1 = plt.gcf()
|
198 |
ax_y1 = plt.gca()
|
199 |
fig_y1.set_figheight(6)
|