Function plot.variable_dropout_explainer plots dropouts for variables used in the model. It uses the output from the variable_dropout function, which corresponds to a permutation-based measure of variable importance.

# S3 method for variable_dropout_explainer
plot(x, ..., max_vars = 10)

Arguments

x

a variable dropout explainer produced with the 'variable_dropout' function

...

other explainers that shall be plotted together

max_vars

maximum number of variables that shall be presented for each model

Value

a ggplot2 object

Examples

library("randomForest") HR_rf_model <- randomForest(left~., data = breakDown::HR_data, ntree = 100)
#> Warning: The response has five or fewer unique values. Are you sure you want to do regression?
explainer_rf <- explain(HR_rf_model, data = HR_data, y = HR_data$left) vd_rf <- variable_dropout(explainer_rf, type = "raw") vd_rf
#> variable dropout_loss label #> 1 _full_model_ 3.338760 randomForest #> 2 left 3.338760 randomForest #> 3 promotion_last_5years 3.398822 randomForest #> 4 Work_accident 4.960937 randomForest #> 5 salary 7.179975 randomForest #> 6 sales 11.828655 randomForest #> 7 time_spend_company 76.371993 randomForest #> 8 average_montly_hours 99.113904 randomForest #> 9 last_evaluation 100.529014 randomForest #> 10 number_project 105.806632 randomForest #> 11 satisfaction_level 170.634216 randomForest #> 12 _baseline_ 341.674293 randomForest
plot(vd_rf)
HR_glm_model <- glm(left~., data = breakDown::HR_data, family = "binomial") explainer_glm <- explain(HR_glm_model, data = HR_data, y = HR_data$left) logit <- function(x) exp(x)/(1+exp(x)) vd_glm <- variable_dropout(explainer_glm, type = "raw", loss_function = function(observed, predicted) sum((observed - logit(predicted))^2)) vd_glm
#> variable dropout_loss label #> 1 _full_model_ 125.9635 lm #> 2 left 125.9635 lm #> 3 last_evaluation 125.9643 lm #> 4 promotion_last_5years 126.4027 lm #> 5 sales 126.9345 lm #> 6 average_montly_hours 127.3949 lm #> 7 time_spend_company 128.4438 lm #> 8 salary 133.9973 lm #> 9 number_project 134.0903 lm #> 10 Work_accident 136.1615 lm #> 11 satisfaction_level 182.1601 lm #> 12 _baseline_ 215.5256 lm
plot(vd_glm)
library("xgboost") model_matrix_train <- model.matrix(left~.-1, breakDown::HR_data) data_train <- xgb.DMatrix(model_matrix_train, label = breakDown::HR_data$left) param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2, objective = "binary:logistic", eval_metric = "auc") HR_xgb_model <- xgb.train(param, data_train, nrounds = 50) explainer_xgb <- explain(HR_xgb_model, data = model_matrix_train, y = HR_data$left, label = "xgboost") vd_xgb <- variable_dropout(explainer_xgb, type = "raw") vd_xgb
#> variable dropout_loss label #> 1 _full_model_ 19.21844 xgboost #> 2 salesRandD 18.89531 xgboost #> 3 salesproduct_mng 19.17630 xgboost #> 4 salessales 19.20264 xgboost #> 5 salesaccounting 19.21844 xgboost #> 6 salesIT 19.21844 xgboost #> 7 salesmanagement 19.21844 xgboost #> 8 salesmarketing 19.21844 xgboost #> 9 saleshr 19.23792 xgboost #> 10 promotion_last_5years 19.40630 xgboost #> 11 salessupport 19.47122 xgboost #> 12 salestechnical 19.68104 xgboost #> 13 salarymedium 20.13957 xgboost #> 14 Work_accident 20.44466 xgboost #> 15 salarylow 20.46020 xgboost #> 16 average_montly_hours 45.56446 xgboost #> 17 number_project 48.57528 xgboost #> 18 last_evaluation 52.17555 xgboost #> 19 time_spend_company 74.85348 xgboost #> 20 satisfaction_level 152.05913 xgboost #> 21 _baseline_ 344.85452 xgboost
plot(vd_xgb)
plot(vd_rf, vd_glm, vd_xgb)
# NOTE: # if you would like to have all importances anchored at 0, you can do this as well vd_rf <- variable_dropout(explainer_rf, type = "difference") vd_glm <- variable_dropout(explainer_glm, type = "difference", loss_function = function(observed, predicted) sum((observed - logit(predicted))^2)) vd_xgb <- variable_dropout(explainer_xgb, type = "difference") plot(vd_rf, vd_glm, vd_xgb)