Partial Dependence Plots

A general framework for constructing partial dependence (i.e., marginal effect) plots from various types of machine learning models in R.


The primary purpose of this package is to provide a general framework for constructing partial dependence plots (PDPs) in R.

The R package pdp is available from CRAN; the development version is hosted on GitHub. There are two ways to install:

# Install latest release from CRAN (recommended)
install.packages("pdp")
 
# Install development version from GitHub repo
devtools::install_github("bgreenwell/pdp")

The examples below demonstrate various uses of the pdp package: regression, classification, and interfacing with the well-known caret package. To start, we need to install a few additional packages that will be required to run the examples.

install.packages(c("caret", "ggmap", "gridExtra", "kernlab", "randomForest", "xgboost"))

In this example, we fit a random forest to the Boston housing data. (See ?boston for a brief explanation of the data.) Note that for any of the following examples you can display a progress bar by specifying progress = "text" in the call to partial, and you can reduce the computation time via the grid.resolution option; both are illustrated in the short sketch after the code below.

# Load required packages
library(pdp)  # for constructing PDPs
library(randomForest)  # for random forest algorithm
library(gridExtra)  # for grid.arrange()
 
# Load the Boston housing data
data(boston)  # included with the pdp package
 
# Fit a random forest to the boston housing data
set.seed(101)  # for reproducibility
boston.rf <- randomForest(cmedv ~ ., data = boston)
 
# Partial dependence of lstat and rm on cmedv
grid.arrange(
  partial(boston.rf, pred.var = "lstat", plot = TRUE, rug = TRUE),
  partial(boston.rf, pred.var = "rm", plot = TRUE, rug = TRUE),
  partial(boston.rf, pred.var = c("lstat", "rm"), plot = TRUE, chull = TRUE),
  ncol = 3
)
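
As noted above, a text progress bar and a coarser grid can help with longer computations. Here is a minimal sketch reusing boston.rf from above; the value 30 for grid.resolution is just an illustrative choice:

# Optional: show a progress bar and evaluate lstat on a coarser grid
partial(boston.rf, pred.var = "lstat", plot = TRUE, rug = TRUE,
        progress = "text", grid.resolution = 30)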

In this example, we fit a support vector machine with a radial basis function kernel to the Pima Indians diabetes data. (See ?pima for a brief explanation of the data.)

# Load required packages
library(kernlab)  # for fitting SVMs
 
# Load the Pima Indians diabetes data
data(pima)  # included with the pdp package
 
# Fit an SVM with a radial basis function kernel to the Pima Indians diabetes data
pima.svm <- ksvm(diabetes ~ ., data = pima, type = "C-svc", kernel = "rbfdot",
                 C = 0.5, prob.model = TRUE)
 
# Partial dependence of glucose and age on diabetes test result (neg/pos). 
partial(pima.svm, pred.var = c("glucose", "age"), plot = TRUE, chull = TRUE,
        train = pima)
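
The plot = TRUE shortcut is convenient, but partial can also return the partial dependence values as a data frame for later plotting with plotPartial. A minimal sketch, assuming the pima.svm fit from above:

# Compute the partial dependence values first, then plot them separately
pd <- partial(pima.svm, pred.var = c("glucose", "age"), chull = TRUE,
              train = pima)
plotPartial(pd)  # displays a level plot by default for two predictors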

Finally, we demonstrate the construction of PDPs from models fit using the caret package; caret is an extremely useful package for classification and regression training that, essentially, provides one function (train) for fitting all kinds of predictive models in R (e.g., glmnet, svm, and xgboost).

For illustration, we use caret's train function to tune an XGBoost model to the Pima Indians diabetes data using 5-fold cross-validation. We then use the final model to construct PDPs for glucose, age, and mass. Note that when training a model using caret's train function, you can view tuning progress by setting verboseIter = TRUE in the call to trainControl (a one-line sketch follows the example below).

# Load required packages
library(caret)  # for model training/tuning
 
# Set up for 5-fold cross-validation
ctrl <- trainControl(method = "cv", number = 5)
 
# Grid of tuning parameter values
xgb.grid <- expand.grid(
  nrounds = 500,
  max_depth = 1:6,
  eta = c(0.001, 0.01, 0.1, 0.2, 0.3, 0.5),
  gamma = 0, 
  colsample_bytree = 1,
  min_child_weight = 1,
  subsample = 1
)
 
# Tune an XGBoost model to the Pima Indians diabetes data using 5-fold
# cross-validation. This may take a few minutes!
set.seed(103)  # for reproducibility
pima.xgb <- train(diabetes ~ ., data = pima, method = "xgbTree",
                  na.action = na.omit, trControl = ctrl,
                  tuneGrid = xgb.grid)
 
# Partial dependence of glucose, age, and mass on diabetes test result (neg/pos)
grid.arrange(
  partial(pima.xgb, pred.var = "glucose", plot = TRUE, rug = TRUE),
  partial(pima.xgb, pred.var = "age", plot = TRUE, rug = TRUE),
  partial(pima.xgb, pred.var = "mass", plot = TRUE, rug = TRUE),
  ncol = 3 
)
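
As noted above, caret can print its tuning progress; the only change needed is in the call to trainControl (shown separately here so the example above stays unchanged):

# Same 5-fold setup, but print tuning progress during cross-validation
ctrl <- trainControl(method = "cv", number = 5, verboseIter = TRUE)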

News


  • The ... argument in the call to partial now refers to additional arguments to be passed on to stats::predict rather than plyr::aaply. For example, using partial with "gbm" objects now requires specifying n.trees, which can simply be passed to partial via the ... argument (see the sketch after this list).
  • Added the following arguments to partial: progress (plyr-based progress bars), parallel (plyr/foreach-based parallel execution), and paropts (list of additional arguments passed onto foreach when parallel = TRUE).
  • Various bug fixes.
  • partial now throws an informative error message when the pred.grid argument refers to predictors not in the original training data.
  • The column name for the predicted value has been changed from "y" to "yhat".
  • randomForest is no longer imported.
  • Added support for the caret package (i.e., objects of class "train").
  • Added example datasets: boston (corrected Boston housing data) and pima (corrected Pima Indians diabetes data).
  • Fixed an error that sometimes occurred when chull = TRUE, which caused the convex hull to not be computed.
  • Refactored plotPartial to be more modular.
  • Added gbm support for most non-"binomial" families.
  • randomForest is now imported.
  • Added examples.
  • Fixed a non-canonical CRAN URL in the README file.
  • partial now makes sure each column of pred.grid has the correct class, levels, etc.
  • partial gained a new option, levelplot, which defaults to TRUE. The original option, contour, has changed and now specifies whether or not to add contour lines whenever levelplot = TRUE.
  • Fixed a number of URLs.
  • More thorough documentation.
  • Fixed a couple of URLs and typos.
  • Added more thorough documentation.
  • Added support for C5.0, Cubist, nonlinear least squares, and XGBoost models.
  • Initial release.
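
To make the first two items above concrete, here is a small, purely illustrative sketch (not part of the original examples); the gbm fit and its tuning values are assumptions chosen only to show where n.trees and the progress/parallel arguments go:

# Fit a GBM to the Boston housing data (illustrative settings only)
library(gbm)  # for generalized boosted regression models
set.seed(102)  # for reproducibility
boston.gbm <- gbm(cmedv ~ ., data = boston, distribution = "gaussian",
                  n.trees = 5000, interaction.depth = 3, shrinkage = 0.01)
 
# n.trees is forwarded through ... to stats::predict; progress adds a
# plyr-based progress bar
partial(boston.gbm, pred.var = "lstat", plot = TRUE, n.trees = 5000,
        progress = "text")
 
# Parallel execution requires a registered foreach backend, e.g.:
#   library(doParallel); registerDoParallel(cores = 2)
#   partial(boston.gbm, pred.var = "lstat", n.trees = 5000, parallel = TRUE)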
