A general framework for constructing partial dependence (i.e., marginal effect) plots from various types of machine learning models in R.
The primary purpose of this package is to provide a general framework for constructing partial dependence plots (PDPs) in R.
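To make the idea of a partial dependence function concrete, here is a minimal brute-force sketch in base R. It is not how pdp is implemented; `manual_pd` is a hypothetical helper written only for illustration, and the linear model on the built-in `mtcars` data is an assumed stand-in for any fitted model with a `predict` method.

```r
# Hypothetical helper (not part of pdp): brute-force partial dependence of the
# model's predictions on a single numeric predictor.
manual_pd <- function(object, pred.var, train, grid.resolution = 20) {
  x <- train[[pred.var]]
  # Equally spaced grid over the observed range of the predictor
  grid <- seq(min(x), max(x), length.out = grid.resolution)
  yhat <- vapply(grid, function(value) {
    temp <- train
    temp[[pred.var]] <- value              # fix the predictor at one grid value
    mean(predict(object, newdata = temp))  # average over all other predictors
  }, numeric(1))
  data.frame(x = grid, yhat = yhat)
}

# Illustration with a simple linear model on a built-in data set
fit <- lm(mpg ~ wt + hp, data = mtcars)
pd <- manual_pd(fit, pred.var = "wt", train = mtcars)
plot(pd$x, pd$yhat, type = "l", xlab = "wt", ylab = "Partial dependence")
```

For an additive linear model such as this one, the resulting curve is a straight line whose slope is the fitted coefficient of `wt`, which makes the sketch easy to sanity-check.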
# Install latest release from CRAN (recommended)
install.packages("pdp")

# Install development version from GitHub repo
devtools::install_github("bgreenwell/pdp")
The examples below demonstrate various usages of the pdp package: regression, classification, and interfacing with the well-known caret package. To start, we need to install a few additional packages that will be required to run the examples.
install.packages(c("caret", "ggmap", "kernlab", "randomForest"))
In this example, we fit a random forest to the Boston housing data. (See ?boston for a brief explanation of the data.) Note that, for any of the following examples, you can display a progress bar by specifying progress = "text" in the call to partial. You may also reduce the computation time via the grid.resolution option in the call to partial.
# Load required packages
library(pdp)           # for constructing PDPs
library(randomForest)  # for random forest algorithm

# Load the Boston housing data
data(boston)  # included with the pdp package

# Fit a random forest to the Boston housing data
set.seed(101)  # for reproducibility
boston.rf <- randomForest(cmedv ~ ., data = boston)

# Partial dependence of cmedv on lstat and rm
grid.arrange(
  partial(boston.rf, pred.var = "lstat", plot = TRUE, rug = TRUE),
  partial(boston.rf, pred.var = "rm", plot = TRUE, rug = TRUE),
  partial(boston.rf, pred.var = c("lstat", "rm"), plot = TRUE, chull = TRUE),
  ncol = 3
)
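As a rough illustration of the grid.resolution option mentioned above, the sketch below computes the partial dependence of cmedv on lstat over a coarse grid and plots it in a separate step. It assumes the pdp and randomForest packages are installed; the object name `pd.lstat` and the choice of 10 grid points are arbitrary choices made for this example.

```r
library(pdp)           # for constructing PDPs
library(randomForest)  # for random forest algorithm

data(boston)  # included with the pdp package
set.seed(101)  # for reproducibility
boston.rf <- randomForest(cmedv ~ ., data = boston)

# Evaluate the partial dependence function at only 10 values of lstat.
# Fewer grid points mean fewer passes over the training data.
pd.lstat <- partial(boston.rf, pred.var = "lstat", grid.resolution = 10)
head(pd.lstat)  # one row per grid point

# Plot the stored result without recomputing it
plotPartial(pd.lstat)
```

Storing the result first and plotting it afterwards is handy when the computation is slow, since the plot can be adjusted without recomputing the partial dependence values.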
In this example, we fit a support vector machine with a radial basis function kernel to the Pima Indians diabetes data. (See ?pima for a brief explanation of the data.)
# Load required packages
library(kernlab)  # for fitting SVMs

# Fit an SVM to the Pima Indians diabetes data
data(pima)  # load the Pima Indians diabetes data
pima.svm <- ksvm(diabetes ~ ., data = pima, type = "C-svc", kernel = "rbfdot",
                 C = 0.5, prob.model = TRUE)

# Partial dependence of the diabetes test result (neg/pos) on glucose and age
partial(pima.svm, pred.var = c("glucose", "age"), plot = TRUE, chull = TRUE,
        train = pima)
Finally, we demonstrate the construction of PDPs from models fit using the caret package. caret is an extremely useful package for classification and regression training that, essentially, has one function (train) for fitting all kinds of predictive models in R. For illustration, we use caret's train function to tune an XGBoost model to the Pima Indians diabetes data using 5-fold cross-validation. We then use the final model to construct PDPs for glucose, age, and mass. Note that, when training a model using caret's train function, you can view tuning progress by setting verboseIter = TRUE in the call to trainControl.
# Load required packages
library(caret)  # for model training/tuning

# Set up for 5-fold cross-validation
ctrl <- trainControl(method = "cv", number = 5)

# Grid of tuning parameter values
xgb.grid <- expand.grid(
  nrounds = 500,
  max_depth = 1:6,
  eta = c(0.001, 0.01, 0.1, 0.2, 0.3, 0.5),
  gamma = 0,
  colsample_bytree = 1,
  min_child_weight = 1,
  subsample = 1
)

# Tune an XGBoost model on the Pima Indians diabetes data. This may take a
# few minutes!
set.seed(103)  # for reproducibility
pima.xgb <- train(diabetes ~ ., data = pima, method = "xgbTree",
                  prob.model = TRUE, na.action = na.omit, trControl = ctrl,
                  tuneGrid = xgb.grid)

# Partial dependence of the diabetes test result (neg/pos) on glucose, age,
# and mass
grid.arrange(
  partial(pima.xgb, pred.var = "glucose", plot = TRUE, rug = TRUE),
  partial(pima.xgb, pred.var = "age", plot = TRUE, rug = TRUE),
  partial(pima.xgb, pred.var = "mass", plot = TRUE, rug = TRUE),
  ncol = 3
)
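The verboseIter setting mentioned above is an argument of caret's trainControl function. A minimal sketch of the same 5-fold cross-validation setup with tuning progress enabled might look as follows; the name `ctrl.verbose` is chosen only for this example.

```r
library(caret)  # for model training/tuning

# Same 5-fold cross-validation setup, but print a log line per resample/tuning
# combination while train() is running
ctrl.verbose <- trainControl(method = "cv", number = 5, verboseIter = TRUE)
```

Passing `trControl = ctrl.verbose` to train in place of `ctrl` should leave the fitted model unchanged and only add progress output.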
... argument in the call to partial now refers to additional arguments to be passed onto
plyr::aaply. For example, using
"gbm" objects will require specification of n.trees which can now simply be passed to
plyr-based progress bars),
foreach-based parallel execution), and
paropts (list of additional arguments passed onto
parallel = TRUE).
partial now throws an informative error message when the pred.grid argument refers to predictors not in the original training data.
randomForest is no longer imported.
caret package (i.e., objects of class
boston (corrected Boston housing data) and pima (corrected Pima Indians diabetes data).
chull = TRUE causing the convex hull to not be computed.
plotPartial to be more modular.
gbm support for most non-
randomForest is now imported.
partial now makes sure each column of pred.grid has the correct class, levels, etc.
partial gained a new option, levelplot, which defaults to TRUE. The original option, contour, has changed and now specifies whether or not to add contour lines whenever levelplot = TRUE.