How parsnip works

  tidymodels, parsnip

  Max Kuhn, Davis Vaughan, Alex Hayes

parsnip was accepted to CRAN late last year. Our first blog post was a top-level introduction. Here, we will go into the design of some of the internals. If you end up looking at the package sources, there are going to be times that you might question our sanity. This post is intended to go over some design decisions and provide some context.

My first solution: caret

Having developed caret in early 2005, I was not prepared for what it would be like to maintain a package that wrapped more than 200 models. caret creates code modules that define the different aspects of modeling, such as model fitting, prediction, and so on. These code modules are contained in lists, one per specific model. For example, here is one of the simplest lists, the one for MASS::lda (slightly abbreviated):

list(
  label = "Linear Discriminant Analysis",
  library = "MASS",
  type = "Classification",
  fit = function(x, y, wts, param, lev, last, classProbs, ...)
    MASS::lda(x, y, ...),
  predict = function(modelFit, newdata, submodels = NULL)
    predict(modelFit, newdata)$class,
  prob = function(modelFit, newdata, submodels = NULL)
    predict(modelFit, newdata)$posterior
)

Some notes:

  1. ✔ The modules are fairly well-defined and mutually exclusive.
  2. ✔ The structure prevents information leakage. The fitting parts are isolated from the prediction parts and the prediction modules only return predictions in a standard format. caret compartmentalizes the modeling process and prevents people from “teaching to the test” unless they deliberately go out of their way to do so.
  3. ✔ These modules cleanly document issues with using the packages. We’d like each model function to have a predictable interface for fitting and prediction, but that’s not the reality for many packages. If nothing else, these modules document the idiosyncratic nature of some packages.
  4. ✔ The lda code above is a fairly clean case. Others are not as nice to look at, especially if they can predict multiple sub-models at once or have problematic interfaces.
  5. ✖ There is a lot of code duplication. For this reason, the directory containing the current set of 238 modules is about 1.3MB.
  6. ✖ The system in caret rigidly defines what parameters can be tuned. It is a bit of a walled garden in this respect.
  7. ✔ Code like MASS::lda would be exposed to R CMD check if it were contained in a package’s R directory. Because of this, early versions had a large number of package dependencies, and caret was almost intolerable when it came time for CRAN to check the package. One way that I fixed this was to compile these code modules into an R list object and treat that as data in the package. In this way, the package R files do not contain much model-specific code and many formal dependencies are avoided.
  8. ✔ The ellipses (...) are heavily utilized here. This makes passing other arguments to the underlying fitting function trivial. This is probably still my favorite thing about the S language.
  9. ✖ In some cases, caret needs to grab an object that might be in the ellipses in order to modify its value. For example, the main tuning parameters for rpart models live in rpart.control. If a user passes in an argument named control, it has to be captured so that the appropriate elements (like cp, maxdepth, or minsplit) can be modified without changing the others. That’s not hard to do, but it eliminates the benefits of using the ellipses. In the end, do.call("rpart", args) is used to fit the model. The downside is that the data objects are embedded in args and, as an unhappy side effect, the data set gets embedded in rpart's call object (see the sketch after this list). That’s really bad.
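
To see why that last point matters, here is a minimal illustration (not caret's actual code) of how do.call() embeds the data in the resulting call object:

library(rpart)
args <- list(formula = mpg ~ ., data = mtcars, control = rpart.control(cp = 0.01))
fit <- do.call("rpart", args)
# Because the arguments were already evaluated, `fit$call` now contains the
# entire mtcars data frame instead of the symbol `mtcars`:
fit$call$data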

When I began at RStudio, I had already been thinking about a different, and hopefully more elegant, way to do this for tidymodels.

A focus on calls and quosures

When parsnip fits a model, it constructs a call object that will be evaluated to create the model fit object (rlang::call2 and rlang::call_modify are excellent). For example, if we were doing this “by-hand” for something simple like glm, an initial function using rlang could be:

library(rlang)
glm_fit <- function(formula, data, evaluate = FALSE, ...) {
  # capture the named arguments as quosures:
  args <- list(formula = rlang::enquo(formula), data = rlang::enquo(data))
  # capture any extra arguments passed through the ellipses:
  args <- c(args, rlang::enquos(...))
  # make the call
  model_call <- rlang::call2("glm", .ns = "stats", !!!args)

  if (evaluate) {
    res <- rlang::eval_tidy(model_call)
  } else {
    res <- model_call
  }
  res
}

glm_fit()
#> stats::glm(formula = ~, data = ~)
glm_fit(family = stats::binomial)
#> stats::glm(formula = ~, data = ~, family = ~stats::binomial)

When these are printed, the tilde indicates which values are quosures. A quosure is a combination of an rlang expression and a reference to the environment in which it originated. Since it is partly an expression, the value has not yet been evaluated inside of glm_fit(). For this reason, when we pass family = stats::binomial, the object stats::binomial is not evaluated. If it were, the value of that object would be embedded into the call (type unclass(binomial()) at an R prompt to see what this looks like).
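
As a small illustration (not parsnip code), compare embedding the evaluated value with deferring it via a quosure:

# The evaluated family object (a list of functions and other elements) gets
# inlined into the call, making it large and hard to read:
eager <- rlang::call2("glm", .ns = "stats", family = stats::binomial())
# A quosure keeps the reference symbolic; the call prints compactly with the
# tilde notation shown above:
lazy <- rlang::call2("glm", .ns = "stats", family = rlang::quo(stats::binomial))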

Here’s a better example:

glm_fit(mpg ~ ., data = mtcars, x = FALSE)
#> stats::glm(formula = ~(mpg ~ .), data = ~mtcars, x = ~FALSE)

The eval_tidy function can evaluate the quosure arguments when the call itself is evaluated. We get our model object as expected:

glm_fit(mpg ~ ., data = mtcars, x = FALSE, evaluate = TRUE)
#> 
#> Call:  stats::glm(formula = ~(mpg ~ .), data = ~mtcars, x = ~FALSE)
#> 
#> Coefficients:
#> (Intercept)          cyl         disp           hp         drat           wt  
#>     12.3034      -0.1114       0.0133      -0.0215       0.7871      -3.7153  
#>        qsec           vs           am         gear         carb  
#>      0.8210       0.3178       2.5202       0.6554      -0.1994  
#> 
#> Degrees of Freedom: 31 Total (i.e. Null);  21 Residual
#> Null Deviance:	    1130 
#> Residual Deviance: 147 	AIC: 164

For parsnip, there are utility functions that create the call and others to evaluate it. This means that the only model-related information needed by the package to define the general glm fitting function would be:

glm_data <- list(func = "glm", package = "stats", required_args = c("formula", "data"))

All of the other arguments are assembled from the model specification object or extra arguments provided to set_engine() and fit(). This allows for a fairly compact representation of the information for the fitting module. Also, it does not expose the code in a way that requires a package dependency and doesn’t contaminate the model call with actual object values. The same strategy can be used to produce predictions and other quantities.
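
As a rough sketch (not parsnip's real internals), a compact module like glm_data could be combined with user-supplied arguments along these lines:

make_fit_call <- function(module, args) {
  # fail early if any required argument is missing
  missing_args <- setdiff(module$required_args, names(args))
  if (length(missing_args) > 0) {
    stop("Missing required argument(s): ", paste(missing_args, collapse = ", "))
  }
  rlang::call2(module$func, .ns = module$package, !!!args)
}

make_fit_call(glm_data, list(formula = quote(mpg ~ .), data = quote(mtcars)))
#> stats::glm(formula = mpg ~ ., data = mtcars)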

There are some other niceties too. If a package, like glmnet, has a non-formula interface and requires the predictors to be in a matrix, the fitting function can simply insert x = as.matrix(x) into the glmnet call instead of doing the matrix conversion prior to the model fit.
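
For instance, a call could be modified so that the coercion happens only when the call is evaluated (a sketch, not the actual parsnip code):

fit_call <- rlang::call2("glmnet", .ns = "glmnet", x = quote(x), y = quote(y))
rlang::call_modify(fit_call, x = quote(as.matrix(x)))
#> glmnet::glmnet(x = as.matrix(x), y = y)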

Unsurprisingly, most models are more complex than our glm example. There are some modeling packages, such as keras or xgboost, that don’t have a one-line call to fit the model (see the keras regression example here). In these cases, a wrapper function is needed, and this most likely results in an extra package dependency.

Also, some prediction methods give back results that require post-processing. For example, class probability predictions for a multiclass xgboost model come back as a vector. If you were to predict four samples of iris data, you would get a 12-element vector back that requires reshaping into the appropriate 4x3 data frame. parsnip handles these cases by having slots for pre-processing the data and/or post-processing the raw prediction results. More information can be found in the vignette for creating a model object. Keep in mind that, as we and other contributors work with the package more, these internals may change slightly in the first few versions.
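
Here is a sketch of that kind of reshaping, assuming each sample's class probabilities are contiguous in the raw vector (as with xgboost's softprob output):

raw <- c(0.70, 0.20, 0.10,   0.10, 0.80, 0.10,
         0.20, 0.20, 0.60,   0.90, 0.05, 0.05)
probs <- as.data.frame(matrix(raw, ncol = 3, byrow = TRUE))
names(probs) <- levels(iris$Species)
probs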

Also, a lot of the package code revolves around getting the arguments right. There are some default arguments set by the package, such as family = binomial for logistic regression. These defaults can be overwritten but may also depend on the mode of the model (e.g. regression, classification, etc.). There are also some arguments that parsnip protects in case the user tries to modify them (e.g. data). Finally, the main arguments to parsnip model functions are standardized and eventually need to be converted back to their engine-specific names.
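
That last translation step might look something like this hypothetical sketch, mapping parsnip's standardized names to ranger's argument names:

standard_to_ranger <- c(trees = "num.trees", min_n = "min.node.size")
rename_args <- function(args, key) {
  # swap in the engine-specific name wherever a standardized name matches
  hits <- names(args) %in% names(key)
  names(args)[hits] <- key[names(args)[hits]]
  args
}
rename_args(list(trees = 1000, min_n = 5), standard_to_ranger)
#> $num.trees
#> [1] 1000
#> 
#> $min.node.size
#> [1] 5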

The downside of using calls

Suppose that you are fitting a random forest regression model and want to tune over mtry, the number of randomly selected predictors to evaluate at each split of the tree. This is a function of the data since it depends on the number of predictors. A simplistic version of the code that iterates over mtry might look like:

library(tidymodels)
rf_model <- rand_forest(mode = "regression")

for (i in 1:4) {
  new_rf <- 
    rf_model %>% 
    update(mtry = i)
  rf_fit <- 
    new_rf %>%
    set_engine("ranger") %>%
    fit(Sepal.Width ~ ., data = iris)
  # now evaluate the model for performance
}

# what is the current value of `mtry`? 
new_rf
#> Random Forest Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = i

The specification depends on i, but its value is not stored in new_rf; there is only a reference to the symbol i in the global environment. What if the value of i changes?

i <- 300
new_rf %>%
  set_engine("ranger") %>%
  fit(Sepal.Width ~ ., data = iris)
#> Error in ranger::ranger(formula = formula, data = data, mtry = ~i, num.threads = 1, : User interrupt or internal error.
# mtry should be between 1 and 4 for these data. 

In this context, using quosures for parameter values is problematic because (by default) their values are not frozen when the specification is created. However, we can fix this using quasiquotation. The value of i can be embedded into the call using !!:

for (i in 1:4) {
  new_rf <- 
    rf_model %>% 
    update(mtry = !!i)
}

# What is the current value of `mtry` now? 
new_rf
#> Random Forest Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = 4

One other downside is related to the data used in the call. In the random forest example above, the value of mtry is dependent on the number of predictors. What if you don’t know how many predictors you have until the model is fit? How could you then specify mtry?

This could happen for a few reasons. For example, if you are using a recipe that has a filter based on reducing correlation, the number of predictors may not be known until the recipe is prepped (and the number may vary inside of resampling). A more common example is related to dummy variables. If I use fit_xy() instead of fit() in the code above, the Species predictor in the iris data is expanded from one column into two dummy variable columns just prior to the model fit. This could affect the possible range of mtry values.

There are a few ways to get around this. The first (and worst) is to use the data in an expression. Since mtry isn’t evaluated until the model fit, you could try to use mtry = floor(sqrt(ncol(data))). That’s very brittle for a few different reasons. For one, you may use the parsnip interface fit(formula, data), but the underlying model may be model_fn(x, y) and the data object doesn’t exist when the model call is evaluated.

For this reason, we added data descriptors to parsnip. These are small functions that only work when the model is being fit and they capture relevant aspects of the data at the time of fit. For example, the function .obs() can be used in an argument value to reference the number of rows in the data. To illustrate, the argument min_n for a random forest model corresponds to how many data points are required to make further splits. For regression models, this defaults to 5 when using ranger. Suppose you want this to be one-tenth of the data but at least 8. If your data are being resampled, you might not know the training set size. You could use:

rf_model %>% 
    update(min_n = max(8, floor(.obs()/10))) %>%
    set_engine("ranger") %>%
    fit(Sepal.Width ~ ., data = iris)
#> parsnip model object
#> 
#> Ranger result
#> 
#> Call:
#>  ranger::ranger(formula = formula, data = data, min.node.size = ~max(8,      floor(.obs()/10)), num.threads = 1, verbose = FALSE, seed = sample.int(10^5,      1)) 
#> 
#> Type:                             Regression 
#> Number of trees:                  500 
#> Sample size:                      150 
#> Number of independent variables:  4 
#> Mtry:                             2 
#> Target node size:                 15 
#> Variable importance mode:         none 
#> Splitrule:                        variance 
#> OOB prediction error (MSE):       0.0827 
#> R squared (OOB):                  0.565
# The math checks out for the "Target node size" above:
max(8, floor(nrow(iris)/10))
#> [1] 15
# Does this work outside of the function call? 
max(8, floor(.obs()/10))
#> Descriptor context not set
# Nope!

The different descriptors should enable a wide variety of values for tuning parameters.
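
As another sketch (assuming the .cols() descriptor, which reports the number of predictor columns at fit time), mtry itself could be tied to the data:

rf_model %>% 
    update(mtry = floor(sqrt(.cols()))) %>%
    set_engine("ranger") %>%
    fit(Sepal.Width ~ ., data = iris)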

Why not just evaluate the arguments?

In many cases, the model specification arguments (e.g. mtry or min_n) and engine arguments are simple objects or scalar values. It makes sense to quote data, x, or y, but why not just evaluate the other arguments as usual?

There are a few reasons. First, there are some arguments whose evaluation should be deferred. For example, stan and ranger models have their own random seed arguments. To enable reproducibility, parsnip gives these functions a default of seed = sample.int(10^5, 1). If this argument were unquoted, then the seed value would be fixed when the package was compiled. There are solutions for this simple example though.
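
A small illustration of the deferral: kept as a quosure, the default produces a fresh value each time the call is evaluated.

seed_quo <- rlang::quo(sample.int(10^5, 1))
rlang::eval_tidy(seed_quo)  # a new draw
rlang::eval_tidy(seed_quo)  # (almost certainly) a different draw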

As seen above with the data descriptors, some argument values need to wait to be evaluated until the same time that the call is evaluated. Originally, parsnip immediately evaluated almost all of the arguments, and our advice was to have users quote special arguments using rlang::expr(). The feedback on this aspect of parsnip was uniformly unfavorable, since it would require many casual users to learn rlang and metaprogramming techniques. For this reason, we moved the metaprogramming parts inside the functions to accomplish the same goals without requiring the user to understand the technical minutiae. When a user writes an argument like min_n = max(8, floor(.obs()/10)), the use of quosures is hidden from view and it looks like they are using an ordinary function called .obs().

One final reason to leave arguments unevaluated in the model specification is related to future plans. As previously mentioned, caret rigidly defined which model parameters were available for performance tuning. The approach taken by parsnip is very much the opposite. We want to enable users to tune any aspect of the model that they see fit, including some of the engine specific parameters.

For example, when fitting a Bayesian regression model, a user might want to tune over how diffuse the prior distribution should be and so on. Rather than formally defining every possible tunable parameter, parsnip allows the user to have a placeholder for parameters in the model specification that declares, “I want to change this parameter, but I don’t know what the exact value should be.”

To do this, a special function called varying() is used. A model cannot be fit if it has any varying parameters but future packages will be able to detect any of these parameters and construct tuning grids accordingly. The parameter values can be changed to specific candidate values, and these are then tested to see which value is the most appropriate. The code to do this is not ready at this point and will be part of another package. However, we can demonstrate how this happens inside of parsnip:

lr_stan_spec <-
  logistic_reg() %>%
  set_engine(
    "stan",
    iter = 5000,
    prior_intercept = rstanarm::student_t(df = varying()),
    seed = 2347
  )

# Which arguments, if any, should be tuned?
varying_args(lr_stan_spec)
#> # A tibble: 5 x 4
#>   name            varying id           type      
#>   <chr>           <lgl>   <chr>        <chr>     
#> 1 penalty         FALSE   logistic_reg model_spec
#> 2 mixture         FALSE   logistic_reg model_spec
#> 3 iter            FALSE   logistic_reg model_spec
#> 4 prior_intercept TRUE    logistic_reg model_spec
#> 5 seed            FALSE   logistic_reg model_spec

nnet_spec <-
  mlp(
    hidden_units = varying(),
    epochs = varying(),
    dropout = varying()
  ) %>%
  set_engine(
    "keras", 
    batch_size = varying(),
    callbacks = callback_early_stopping(monitor = 'loss', min_delta = varying())
  )

varying_args(nnet_spec)
#> # A tibble: 7 x 4
#>   name         varying id    type      
#>   <chr>        <lgl>   <chr> <chr>     
#> 1 hidden_units TRUE    mlp   model_spec
#> 2 penalty      FALSE   mlp   model_spec
#> 3 dropout      TRUE    mlp   model_spec
#> 4 epochs       TRUE    mlp   model_spec
#> 5 activation   FALSE   mlp   model_spec
#> 6 batch_size   TRUE    mlp   model_spec
#> 7 callbacks    TRUE    mlp   model_spec

In the keras example, the argument names in nnet_spec match up with objects in the dials package, so it will be possible to automatically create tuning grids for all of the parameters.

By avoiding evaluation in the model specification, we are enabling some interesting upcoming features.

This feature will also be used with recipes and other future packages, enabling joint optimization of parameters associated with pre-processing, model fitting, and post-processing activities.