Spaces:
Sleeping
Sleeping
suppressPackageStartupMessages(library(dplyr)) | |
suppressPackageStartupMessages(library(ggplot2)) | |
suppressPackageStartupMessages(library(ggExtra)) | |
suppressPackageStartupMessages(library(forcats)) | |
#' Rescale a numeric vector to the unit interval [0, 1]
#'
#' Missing values are ignored when computing the range and stay `NA` in the
#' output. A vector with no spread (all observed values equal) maps every
#' observed value to 0 instead of producing `NaN` from `0 / 0`.
#'
#' @param x a numeric vector
#' @return a numeric vector the same length as `x`
range01 <- function(x) {
  if (all(is.na(x))) {
    return(rep(NA_real_, length(x)))
  }
  rng <- range(x, na.rm = TRUE)
  span <- rng[2] - rng[1]
  # isTRUE() guards against a NaN span (e.g. only +Inf values observed).
  if (isTRUE(span == 0)) {
    return(ifelse(is.na(x), NA_real_, 0))
  }
  (x - rng[1]) / span
}
#' Stacked bar chart of mean |SHAP| per feature, split by activity
#'
#' @param shap_values a data frame of SHAP values with one column per feature
#'   and a `class` column naming the activity of each row (id columns are
#'   guessed by `reshape2::melt()`)
#' @return a ggplot object; features are ordered so the one with the largest
#'   total mean |SHAP| (summed over activities) sits at the top of the axis
shap_summary_plot <- function(shap_values) {
  mean_abs_shap <- shap_values %>%
    reshape2::melt() %>%
    group_by(class, variable) %>%
    summarise(mean = mean(abs(value))) %>%
    ungroup()
  # A discrete y axis follows factor levels, not row order, so sorting the rows
  # (the former arrange(desc(mean))) never reached the plot. Reorder the factor
  # by each feature's stacked total instead.
  ggplot(mean_abs_shap) +
    # ggdark::dark_theme_classic() +
    theme_classic() +
    geom_col(
      aes(
        y = fct_reorder(variable, mean, .fun = sum),
        x = mean,
        group = class,
        fill = class
      ),
      position = "stack"
    ) +
    ylab("Feature") +
    xlab("Mean(|Shap Value|) Average impact on model output magnitude per activity.") +
    guides(fill = guide_legend(title = "Activity"))
}
#' Bar chart of mean |SHAP| per feature for a single activity
#'
#' @param shap_values a data frame of SHAP values with one column per feature
#'   and a `class` column naming the activity of each row
#' @param class the activity whose rows are summarised (default "G")
#' @param color fill colour of the bars
#' @return a ggplot object with features ordered by mean |SHAP|, largest at
#'   the top of the axis
shap_summary_plot_perclass <- function(shap_values, class = "G", color = "#F8766D") {
  # `.data$class` is the data column; `.env$class` is this function's argument.
  # The former `{{ class }}` injected the caller's *expression* into the data
  # mask, so a call like `shap_summary_plot_perclass(sv, class = class)`
  # compared the column with itself and silently kept every row.
  class_shap <- shap_values %>%
    as.data.frame() %>%
    filter(.data$class == .env$class)
  mean_abs_shap <- class_shap %>%
    reshape2::melt() %>%
    group_by(variable) %>%
    summarise(mean = mean(abs(value)))
  # `fill` is a constant here (not mapped), so no legend is produced and the
  # previous guides(fill = ...) call had no effect; it has been dropped.
  ggplot(mean_abs_shap) +
    theme_classic() +
    geom_col(aes(x = mean, y = fct_reorder(variable, mean)), fill = color) +
    ylab("Feature") +
    xlab(paste0("Mean(|Shap Value|) Average impact on model output magnitude for activity ", class))
}
#' Beeswarm (sina) plot of SHAP values per feature, faceted by activity
#'
#' @param shap_values a data frame of SHAP values: one column per feature plus
#'   a `class` column
#' @param dataset a data frame of feature values with the same feature columns
#'   (same order, same rows) plus an `Activity` column
#' @return a ggplot object; each point is filled by its feature value rescaled
#'   to [0, 1] within that feature
shap_beeswarm_plot <- function(shap_values, dataset) {
  shap_long <- shap_values %>% reshape2::melt()
  feature_long <- dataset %>%
    mutate(class = Activity) %>%
    select(-Activity) %>%
    reshape2::melt() %>%
    group_by(variable) %>%
    mutate(value_scale = range01(value)) %>%
    ungroup()
  # SHAP values and feature values are paired purely by row position below.
  # Check that both long tables list the same features in the same order;
  # cbind() would otherwise silently recycle a shorter vector whose length
  # divides the row count and mis-colour the points.
  same_layout <- nrow(shap_long) == nrow(feature_long) &&
    identical(as.character(shap_long$variable), as.character(feature_long$variable))
  if (!same_layout) {
    stop(
      "`shap_values` and `dataset` must have the same feature columns (in the same order) and the same rows.",
      call. = FALSE
    )
  }
  cbind(shap_long, feature_value = feature_long$value_scale) %>%
    # filter(class=="GM") %>%
    ggplot() +
    facet_wrap(~class) +
    # ggdark::dark_theme_bw() +
    theme_classic() +
    geom_hline(yintercept = 0, color = "red", size = 0.5) +
    ggforce::geom_sina(
      aes(x = variable, y = value, fill = feature_value),
      color = "black", size = 2.4, bins = 4, alpha = 0.9, shape = 22
    ) +
    # Only one fill scale can be active: a yellow-to-red gradient that used to
    # precede this one was always replaced by it, so it has been removed.
    scale_fill_gradient(low = "skyblue", high = "orange", na.value = NA) +
    xlab("Feature") +
    ylab("SHAP value") +
    theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))
}
#' Dependency plot for a particular feature. The plot considers
#' activities and FP/TP
#'
#' For every activity in `dataset$Activity`, plots the SHAP value of `feature`
#' against the feature value rescaled to [0, 1] (over the whole dataset). Points
#' are coloured by whether the prediction was correct ("TP") or not ("FP"). Each
#' panel adds a dashed smoother and marginal histograms, and the panels are
#' arranged in a grid.
#'
#' @param feature name (a string) of the feature to plot
#' @param dataset a dataset with goat information; besides the feature column
#'   it must contain `Activity` (true label) and `predictions` (predicted label)
#' @param shap a shap value dataset with one column per feature and a `class`
#'   column; for each activity its rows must line up with the rows of
#'   `dataset` for that activity
#' @param ncol number of columns in the panel grid (default 4)
#'
#' @return the value returned by `gridExtra::grid.arrange()`; the panel grid
#'   is drawn as a side effect
#' @export
#'
#' @examples
#'
#' dataset <-
#'   readr::read_delim("data/split/seba-caprino_loocv.tsv",
#'                     delim = '\t')
#' selected_variables <-
#'   readr::read_delim(
#'     "data/topnfeatures/seba-caprino_selected_features.tsv",
#'     col_types = readr::cols(),
#'     delim = '\t'
#'   )
#' # NOTE: dependency_plot() also reads a `predictions` column from `dataset`.
#' dataset <-
#'   dataset %>% select(selected_variables$variable,
#'                      Anim,
#'                      Activity)
#' goat_model <- readRDS("models/boost/seba-caprino_model.rds")
#' shap_values <- calculate_shap(dataset,
#'                               model = goat_model,
#'                               nsim = 30)
#' dependency_plot(feature = "Steps",
#'                 dataset = dataset,
#'                 shap = shap_values)
dependency_plot <- function(feature, dataset, shap, ncol = 4) {
  # `!!feature :=` names the column after the string stored in `feature`.
  # `{{ feature }} :=` used the caller's expression instead, so passing the
  # name through a variable (f <- "Steps"; dependency_plot(f, ...)) created a
  # column called `f` and left the real feature unscaled.
  newdata <- dataset %>% mutate(!!feature := range01(.data[[feature]]))
  activities <- dataset %>% pull(Activity) %>% unique()
  plots <- list()
  for (activity in activities) {
    # Pick the feature column by name; the former hard-coded `1:18` column
    # slice failed for any feature stored beyond column 18.
    shap_activity <- as.data.frame(shap)[which(shap$class == activity), feature]
    x <- newdata[which(newdata$Activity == activity), ]
    data <- data.frame(
      shap = shap_activity,
      feature = x[[feature]],
      tp = x %>%
        mutate(tp = ifelse(Activity == predictions, "TP", "FP")) %>%
        pull(tp)
    )
    p <- ggplot(data, aes(x = feature)) +
      geom_point(aes(y = shap, color = tp), alpha = 0.3, size = 0.8) +
      geom_smooth(aes(y = shap), se = FALSE, size = 0.5, linetype = "dashed") +
      geom_hline(yintercept = 0, color = "red", size = 0.5, alpha = 0.5) +
      xlab(feature) +
      labs(title = paste0("Activity ", activity)) +
      ylab("SHAP Value") +
      ylim(-0.1, 0.4) +
      xlim(0, 1) +
      theme_light() +
      theme(legend.position = "none")
    plots[[activity]] <- ggMarginal(
      p,
      type = "histogram",
      fill = "gray",
      color = "white",
      size = 10,
      xparams = list(bins = 25),
      yparams = list(bins = 15)
    )
  }
  # gridExtra is never attached by this file, so call it by namespace.
  do.call(gridExtra::grid.arrange, c(plots, ncol = ncol))
}
#' Dependency plot for a particular feature on a particular animal.
#' The plot considers activities and FP/TP
#'
#' For every activity observed for animal `anim`, plots that animal's SHAP
#' values of `feature` against the feature value. The feature is rescaled to
#' [0, 1] over the whole dataset, not only this animal. Points are coloured
#' TP/FP, and each panel gets a dashed smoother and marginal histograms. One
#' grid column is used per activity.
#'
#' @param feature name (a string) of the feature to plot
#' @param dataset a dataset with goat information; besides the feature column
#'   it must contain `Anim`, `Activity` (true label) and `predictions`
#'   (predicted label)
#' @param shap a shap value dataset with one column per feature plus `class`
#'   and `Anim` columns; its rows for each animal/activity must line up with
#'   the matching rows of `dataset`
#' @param anim the id of the animal
#' @return the value returned by `gridExtra::grid.arrange()`; the panel grid
#'   is drawn as a side effect
#' @export
#'
#' @examples
#'
#' dataset <-
#'   readr::read_delim("data/split/seba-caprino_loocv.tsv",
#'                     delim = '\t')
#' selected_variables <-
#'   readr::read_delim(
#'     "data/topnfeatures/seba-caprino_selected_features.tsv",
#'     col_types = readr::cols(),
#'     delim = '\t'
#'   )
#' # NOTE: dependency_plot_anim() also reads a `predictions` column from `dataset`.
#' dataset <-
#'   dataset %>% select(selected_variables$variable,
#'                      Anim,
#'                      Activity)
#' goat_model <- readRDS("models/boost/seba-caprino_model.rds")
#' shap_values <- calculate_shap(dataset,
#'                               model = goat_model,
#'                               nsim = 30)
#' dependency_plot_anim(feature = "Steps",
#'                      dataset = dataset,
#'                      shap = shap_values,
#'                      anim = 'a13')
dependency_plot_anim <- function(feature, dataset, shap, anim) {
  # `!!feature :=` names the column after the string stored in `feature`;
  # `{{ feature }} :=` used the caller's expression (e.g. a variable name).
  newdata <- dataset %>% mutate(!!feature := range01(.data[[feature]]))
  # `.env$anim` is the argument, never a data column that happens to share
  # its name.
  activities <- newdata %>%
    filter(Anim == .env$anim) %>%
    pull(Activity) %>%
    unique()
  plots <- list()
  for (activity in activities) {
    shap_rows <- which(shap$class == activity & shap$Anim == anim)
    data_rows <- which(newdata$Activity == activity & newdata$Anim == anim)
    x <- newdata[data_rows, ]
    data <- data.frame(
      # Select the feature by name instead of the hard-coded `1:18` slice.
      shap = as.data.frame(shap)[shap_rows, feature],
      feature = x[[feature]],
      tp = x %>%
        mutate(tp = ifelse(Activity == predictions, "TP", "FP")) %>%
        pull(tp)
    )
    p <- ggplot(data, aes(x = feature)) +
      geom_point(aes(y = shap, color = tp), alpha = 0.3, size = 1.8) +
      geom_smooth(aes(y = shap), se = FALSE, size = 0.5, linetype = "dashed") +
      geom_hline(yintercept = 0, color = "red", size = 0.5, alpha = 0.5) +
      xlab(feature) +
      labs(title = paste0("Activity ", activity)) +
      ylab("SHAP Value") +
      ylim(-0.1, 0.4) +
      xlim(0, 1) +
      theme_light() +
      theme(legend.position = "none")
    plots[[activity]] <- ggMarginal(
      p,
      type = "histogram",
      fill = "gray",
      color = "white",
      size = 15,
      xparams = list(bins = 25),
      yparams = list(bins = 15)
    )
  }
  # gridExtra is never attached by this file, so call it by namespace.
  do.call(gridExtra::grid.arrange, c(plots, ncol = length(activities)))
}
#' contribution plot for SHAP values
#'
#' Draws one horizontal bar per feature for a single observation. Bar length
#' and fill are that feature's SHAP value.
#'
#' @param s a data frame of SHAP values (for a particular class, animal, etc.)
#'   whose first `n_features` columns are numeric feature contributions;
#'   presumably as returned by `calculate_shap_class()` in the example — confirm
#' @param num_row the row number of the observation to show
#' @param n_features number of leading feature columns to plot (default 15)
#'
#' @return ggplot object
#' @export
#'
#' @examples
#'
#' shap_values_G <- calculate_shap_class(
#'   dataset = dataset,
#'   new_data = newdata,
#'   model = model,
#'   nsim = 100,
#'   function_class = p_function_G,
#'   class_name = "G")
#' p1 <- contribution_plot(shap_values_G, num_row = 1) +
#'   labs(title = "Anim a13: class G (FN)", subtitle = "SHAP analysis for class G")
#'
contribution_plot <- function(s, num_row = 1, n_features = 15) {
  # drop = FALSE keeps a one-row data frame (and its names) even when only a
  # single feature column is requested.
  observation <- s[num_row, seq_len(n_features), drop = FALSE]
  # One row holds exactly one value per feature, so flattening it yields the
  # contributions directly; apply() would first coerce the frame to a matrix.
  contributions <- data.frame(
    Variable = names(observation),
    Importance = unlist(observation, use.names = FALSE)
  )
  # The former aes(Variable, Importance, Importance, ...) passed a stray,
  # unnamed third aesthetic; only x, y and fill are meaningful here.
  ggplot(contributions, aes(Variable, Importance, fill = Importance)) +
    geom_col() +
    coord_flip() +
    xlab("") +
    ylab("Shapley value") +
    theme_classic() +
    theme(legend.position = "none")
}
#' Contribution plot for SHAP values, annotated with the feature values
#'
#' Draws one horizontal bar per feature for a single observation. Bar length
#' is the SHAP value. Fill and the text label are the observation's feature
#' value, rounded to 2 digits.
#'
#' @param s a data frame of SHAP values whose first `n_features` columns are
#'   numeric feature contributions
#' @param f a data frame of feature values with the same leading feature
#'   columns (and row order) as `s`
#' @param num_row the row number of the observation to show
#' @param n_features number of leading feature columns to plot (default 15)
#' @return ggplot object
contribution_plot_w_feature <- function(s, f, num_row = 1, n_features = 15) {
  feature_cols <- seq_len(n_features)
  shap_row <- s[num_row, feature_cols, drop = FALSE]
  feature_row <- f[num_row, feature_cols, drop = FALSE]
  # Flatten the single rows instead of apply(), which coerces to a matrix.
  d <- data.frame(
    variable = names(shap_row),
    importance = unlist(shap_row, use.names = FALSE),
    value = unlist(feature_row, use.names = FALSE)
  )
  # The former aes(variable, importance, value, ...) passed a stray unnamed
  # aesthetic, and `hjust` was a constant mapped inside aes(); it is now a
  # fixed layer parameter.
  ggplot(d, aes(variable, importance, fill = value)) +
    geom_col() +
    geom_text(aes(label = round(value, digits = 2)), hjust = 1.0, size = 2) +
    coord_flip() +
    xlab("") +
    ylab("Shapley value") +
    scale_fill_gradient(low = "lightgray", high = "skyblue") +
    theme_classic() +
    theme(legend.position = "none")
}