New interface to validation splits

  tidymodels, tune, rsample

  Hannah Frick

We’re chuffed to announce the release of a new interface to validation splits in rsample 1.2.0 and tune 1.1.2. The rsample package makes it easy to create resamples for assessing model performance. The tune package facilitates hyperparameter tuning for the tidymodels packages.

You can install the new versions from CRAN with:

install.packages(c("rsample", "tune"))

This blog post will walk you through how to make a validation split and use it for tuning.

You can see a full list of changes in the release notes for rsample and tune.

Let’s start by loading the tidymodels package, which will load, among others, both rsample and tune.

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
#>  broom        1.0.5      recipes      1.0.7
#>  dials        1.2.0      rsample      1.2.0
#>  dplyr        1.1.2      tibble       3.2.1
#>  ggplot2      3.4.3      tidyr        1.3.0
#>  infer        1.0.4      tune         1.1.2
#>  modeldata    1.2.0      workflows    1.1.3
#>  parsnip      1.1.1      workflowsets 1.0.1
#>  purrr        1.0.2      yardstick    1.2.0
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#>  purrr::discard() masks scales::discard()
#>  dplyr::filter()  masks stats::filter()
#>  dplyr::lag()     masks stats::lag()
#>  recipes::step()  masks stats::step()
#>  Use suppressPackageStartupMessages() to eliminate package startup messages

The new functions

You can now make a three-way split of your data instead of doing a sequence of two binary splits.

  • initial_validation_split() with variants initial_validation_time_split() and group_initial_validation_split() for the initial three-way split
  • validation_set() to create the rset for tuning containing the analysis (= training) and assessment (= validation) set
  • training(), validation(), and testing() for access to the separate subsets
  • last_fit() (and fit_best()) now also work on the initial three-way split
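
To see how these pieces fit together at a glance, here is a quick sketch using mtcars purely as a stand-in dataset; the rest of this post walks through a full example.

set.seed(1)
split <- initial_validation_split(mtcars)

# access the three subsets
training(split)
validation(split)
testing(split)

# bundle the training and validation sets into an rset for tuning
val_set <- validation_set(split)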

The new functions in action

To illustrate how to use the new functions, we’ll replicate an analysis of childcare costs from a Tidy Tuesday dataset that Julia Silge analyzed in one of her screencasts.

We are modeling mcsa, the median weekly price for school-aged kids in childcare centers, and are thus removing the other variables containing different variants of median prices (e.g., for different age groups). We are also removing the FIPS code identifying the county, since we are including various characteristics of the counties instead of their ID.

library(readr)
#> 
#> Attaching package: 'readr'
#> The following object is masked from 'package:yardstick':
#> 
#>     spec
#> The following object is masked from 'package:scales':
#> 
#>     col_factor

childcare_costs <- read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2023/2023-05-09/childcare_costs.csv')
#> Rows: 34567 Columns: 61
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> dbl (61): county_fips_code, study_year, unr_16, funr_16, munr_16, unr_20to64...
#> 
#>  Use `spec()` to retrieve the full column specification for this data.
#>  Specify the column types or set `show_col_types = FALSE` to quiet this message.

childcare_costs <- childcare_costs |>
  select(-matches("^mc_|^mfc")) |>
  select(-county_fips_code) |>
  drop_na() 

glimpse(childcare_costs)
#> Rows: 23,593
#> Columns: 53
#> $ study_year                <dbl> 2008, 2009, 2010, 2011, 2012, 2013, 2014, 20…
#> $ unr_16                    <dbl> 5.42, 5.93, 6.21, 7.55, 8.60, 9.39, 8.50, 7.…
#> $ funr_16                   <dbl> 4.41, 5.72, 5.57, 8.13, 8.88, 10.31, 9.18, 8…
#> $ munr_16                   <dbl> 6.32, 6.11, 6.78, 7.03, 8.29, 8.56, 7.95, 6.…
#> $ unr_20to64                <dbl> 4.6, 4.8, 5.1, 6.2, 6.7, 7.3, 6.8, 5.9, 4.4,…
#> $ funr_20to64               <dbl> 3.5, 4.6, 4.6, 6.3, 6.4, 7.6, 6.8, 6.1, 4.6,…
#> $ munr_20to64               <dbl> 5.6, 5.0, 5.6, 6.1, 7.0, 7.0, 6.8, 5.9, 4.3,…
#> $ flfpr_20to64              <dbl> 68.9, 70.8, 71.3, 70.2, 70.6, 70.7, 69.9, 68…
#> $ flfpr_20to64_under6       <dbl> 66.9, 63.7, 67.0, 66.5, 67.1, 67.5, 65.2, 66…
#> $ flfpr_20to64_6to17        <dbl> 79.59, 78.41, 78.15, 77.62, 76.31, 75.91, 75…
#> $ flfpr_20to64_under6_6to17 <dbl> 60.81, 59.91, 59.71, 59.31, 58.30, 58.00, 57…
#> $ mlfpr_20to64              <dbl> 84.0, 86.2, 85.8, 85.7, 85.7, 85.0, 84.2, 82…
#> $ pr_f                      <dbl> 8.5, 7.5, 7.5, 7.4, 7.4, 8.3, 9.1, 9.3, 9.4,…
#> $ pr_p                      <dbl> 11.5, 10.3, 10.6, 10.9, 11.6, 12.1, 12.8, 12…
#> $ mhi_2018                  <dbl> 58462.55, 60211.71, 61775.80, 60366.88, 5915…
#> $ me_2018                   <dbl> 32710.60, 34688.16, 34740.84, 34564.32, 3432…
#> $ fme_2018                  <dbl> 25156.25, 26852.67, 27391.08, 26727.68, 2796…
#> $ mme_2018                  <dbl> 41436.80, 43865.64, 46155.24, 45333.12, 4427…
#> $ total_pop                 <dbl> 49744, 49584, 53155, 53944, 54590, 54907, 55…
#> $ one_race                  <dbl> 98.1, 98.6, 98.5, 98.5, 98.5, 98.6, 98.7, 98…
#> $ one_race_w                <dbl> 78.9, 79.1, 79.1, 78.9, 78.9, 78.3, 78.0, 77…
#> $ one_race_b                <dbl> 17.7, 17.9, 17.9, 18.1, 18.1, 18.4, 18.6, 18…
#> $ one_race_i                <dbl> 0.4, 0.4, 0.3, 0.2, 0.3, 0.3, 0.4, 0.4, 0.4,…
#> $ one_race_a                <dbl> 0.4, 0.6, 0.7, 0.7, 0.8, 1.0, 0.9, 1.0, 0.8,…
#> $ one_race_h                <dbl> 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1,…
#> $ one_race_other            <dbl> 0.7, 0.7, 0.6, 0.5, 0.4, 0.7, 0.7, 0.9, 1.4,…
#> $ two_races                 <dbl> 1.9, 1.4, 1.5, 1.5, 1.5, 1.4, 1.3, 1.6, 2.0,…
#> $ hispanic                  <dbl> 1.8, 2.0, 2.3, 2.4, 2.4, 2.5, 2.5, 2.6, 2.6,…
#> $ households                <dbl> 18373, 18288, 19718, 19998, 19934, 20071, 20…
#> $ h_under6_both_work        <dbl> 1543, 1475, 1569, 1695, 1714, 1532, 1557, 13…
#> $ h_under6_f_work           <dbl> 970, 964, 1009, 1060, 938, 880, 1191, 1258, …
#> $ h_under6_m_work           <dbl> 22, 16, 16, 106, 120, 161, 159, 211, 109, 10…
#> $ h_under6_single_m         <dbl> 995, 1099, 1110, 1030, 1095, 1160, 954, 883,…
#> $ h_6to17_both_work         <dbl> 4900, 5028, 5472, 5065, 4608, 4238, 4056, 40…
#> $ h_6to17_fwork             <dbl> 1308, 1519, 1541, 1965, 1963, 1978, 2073, 20…
#> $ h_6to17_mwork             <dbl> 114, 92, 113, 246, 284, 354, 373, 551, 322, …
#> $ h_6to17_single_m          <dbl> 1966, 2305, 2377, 2299, 2644, 2522, 2269, 21…
#> $ emp_m                     <dbl> 27.40, 29.54, 29.33, 31.17, 32.13, 31.74, 32…
#> $ memp_m                    <dbl> 24.41, 26.07, 25.94, 26.97, 28.59, 27.44, 28…
#> $ femp_m                    <dbl> 30.68, 33.40, 33.06, 35.96, 36.09, 36.61, 37…
#> $ emp_service               <dbl> 17.06, 15.81, 16.92, 16.18, 16.09, 16.72, 16…
#> $ memp_service              <dbl> 15.53, 14.16, 15.09, 14.21, 14.71, 13.92, 13…
#> $ femp_service              <dbl> 18.75, 17.64, 18.93, 18.42, 17.63, 19.89, 20…
#> $ emp_sales                 <dbl> 29.11, 28.75, 29.07, 27.56, 28.39, 27.22, 25…
#> $ memp_sales                <dbl> 15.97, 17.51, 17.82, 17.74, 17.79, 17.38, 15…
#> $ femp_sales                <dbl> 43.52, 41.25, 41.43, 38.76, 40.26, 38.36, 36…
#> $ emp_n                     <dbl> 13.21, 11.89, 11.57, 10.72, 9.02, 9.27, 9.38…
#> $ memp_n                    <dbl> 22.54, 20.30, 19.86, 18.28, 16.03, 16.79, 17…
#> $ femp_n                    <dbl> 2.99, 2.52, 2.45, 2.09, 1.19, 0.77, 0.58, 0.…
#> $ emp_p                     <dbl> 13.22, 14.02, 13.11, 14.38, 14.37, 15.04, 16…
#> $ memp_p                    <dbl> 21.55, 21.96, 21.28, 22.80, 22.88, 24.48, 24…
#> $ femp_p                    <dbl> 4.07, 5.19, 4.13, 4.77, 4.84, 4.36, 6.07, 7.…
#> $ mcsa                      <dbl> 80.92, 83.42, 85.92, 88.43, 90.93, 93.43, 95…

Even after omitting rows with missing values, we are left with 23,593 observations. That is plenty to work with! We are likely to get a reliable estimate of model performance from a validation set without having to fit and evaluate the model multiple times, as we would with, for example, v-fold cross-validation.

We are creating a three-way split of the data into a training, a validation, and a test set with the new initial_validation_split() function. We are stratifying based on our outcome mcsa. The default of prop = c(0.6, 0.2) means that 60% of the data is allocated to the training set and 20% to the validation set, with the remaining 20% going into the test set.

set.seed(123)
childcare_split <- childcare_costs |>
  initial_validation_split(strata = mcsa)
childcare_split
#> <Training/Validation/Testing/Total>
#> <14155/4718/4720/23593>
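
If you want different proportions, you can set prop explicitly. For example, a 70/15/15 split would look like this (just a sketch, not run here):

childcare_costs |>
  initial_validation_split(prop = c(0.7, 0.15), strata = mcsa)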

You can access the subsets of the data with the familiar training() and testing() as well as the new validation():

validation(childcare_split)
#> # A tibble: 4,718 × 53
#>    study_year unr_16 funr_16 munr_16 unr_20to64 funr_20to64 munr_20to64
#>         <dbl>  <dbl>   <dbl>   <dbl>      <dbl>       <dbl>       <dbl>
#>  1       2013   9.39   10.3     8.56        7.3         7.6         7  
#>  2       2011  13.0    12.4    13.6        13.2        12.4        13.9
#>  3       2008   3.85    4.4     3.43        3.7         3.9         3.6
#>  4       2015   8.31   11.8     5.69        7.8        11.7         4.9
#>  5       2015   7.67    6.92    8.27        7.6         6.7         8.3
#>  6       2016   5.95    6.33    5.66        5.7         5.9         5.5
#>  7       2009  10.7    15.9     7.06        8.7        16.8         2.9
#>  8       2010  11.2    15.2     7.89       10.9        14.7         7.8
#>  9       2013  15.0    17.0    13.4        15.2        18.1        13  
#> 10       2014  17.4    16.3    18.2        17.2        17.7        16.9
#> # ℹ 4,708 more rows
#> # ℹ 46 more variables: flfpr_20to64 <dbl>, flfpr_20to64_under6 <dbl>,
#> #   flfpr_20to64_6to17 <dbl>, flfpr_20to64_under6_6to17 <dbl>,
#> #   mlfpr_20to64 <dbl>, pr_f <dbl>, pr_p <dbl>, mhi_2018 <dbl>, me_2018 <dbl>,
#> #   fme_2018 <dbl>, mme_2018 <dbl>, total_pop <dbl>, one_race <dbl>,
#> #   one_race_w <dbl>, one_race_b <dbl>, one_race_i <dbl>, one_race_a <dbl>,
#> #   one_race_h <dbl>, one_race_other <dbl>, two_races <dbl>, hispanic <dbl>, …
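
If you just want to check the sizes of all three subsets, you could, for example, look at their dimensions (output omitted):

dim(training(childcare_split))
dim(validation(childcare_split))
dim(testing(childcare_split))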

You may want to extract the training data to do some exploratory data analysis, but here we are going to rely on xgboost to figure out patterns in the data, so we can breeze straight to tuning a model.

xgb_spec <-
  boost_tree(
    trees = 500,        # a fixed number of boosting iterations
    min_n = tune(),     # minimal node size, to be tuned
    mtry = tune(),      # number of predictors sampled at each split, to be tuned
    stop_iter = tune(), # early stopping patience, to be tuned
    learn_rate = 0.01
  ) |>
  # `validation` is the proportion of the data xgboost holds out internally
  # to assess performance for early stopping
  set_engine("xgboost", validation = 0.2) |>
  set_mode("regression")

xgb_wf <- workflow(mcsa ~ ., xgb_spec)
xgb_wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> mcsa ~ .
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = tune()
#>   trees = 500
#>   min_n = tune()
#>   learn_rate = 0.01
#>   stop_iter = tune()
#> 
#> Engine-Specific Arguments:
#>   validation = 0.2
#> 
#> Computational engine: xgboost

We pass this workflow object, with its model specification, to tune_grid() to try multiple combinations of the hyperparameters we tagged for tuning (min_n, mtry, and stop_iter).

During tuning, the model should not have access to the test data, only to the data used to fit the model (the analysis set) and the data used to assess the model (the assessment set). Each pair of analysis and assessment set forms a resample. For 10-fold cross-validation, we’d have 10 resamples. With a validation split, we have just one resample with the training set functioning as the analysis set and the validation set as the assessment set. The tidymodels tuning functions all expect a set of resamples (which can be of size one) and the corresponding objects are of class rset.

To remove the test data from the initial three-way split and create such an rset object for tuning, use validation_set().

set.seed(234)
childcare_set <- validation_set(childcare_split)
childcare_set
#> # A tibble: 1 × 2
#>   splits               id        
#>   <list>               <chr>     
#> 1 <split [14155/4718]> validation
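
This single resample pairs the training set (as the analysis set) with the validation set (as the assessment set). To convince yourself of that, you could peek at the underlying split with rsample’s analysis() and assessment() accessors (a small sketch, output omitted):

nrow(analysis(childcare_set$splits[[1]]))   # 14155, the size of the training set
nrow(assessment(childcare_set$splits[[1]])) # 4718, the size of the validation set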

We are going to try 15 different parameter combinations and pick the one with the smallest RMSE.

set.seed(234)
xgb_res <- tune_grid(xgb_wf, childcare_set, grid = 15)
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> Warning in `[.tbl_df`(x, is.finite(x <- as.numeric(x))): NAs introduced by coercion
best_parameters <- select_best(xgb_res, "rmse")
childcare_wflow <- finalize_workflow(xgb_wf, best_parameters)
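
If you’d rather look at several strong candidates instead of only the single best one, you could also inspect the top results, for example with:

show_best(xgb_res, metric = "rmse")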

last_fit() then lets you fit your model on the training data and calculate performance on the test data. If you provide it with a three-way split, you can choose whether you want your model to be fitted on the training data only or on the combination of the training and validation sets. You can specify this with the add_validation_set argument.

childcare_fit <- last_fit(childcare_wflow, childcare_split, add_validation_set = TRUE)
collect_metrics(childcare_fit)
#> # A tibble: 2 × 4
#>   .metric .estimator .estimate .config             
#>   <chr>   <chr>          <dbl> <chr>               
#> 1 rmse    standard      21.4   Preprocessor1_Model1
#> 2 rsq     standard       0.610 Preprocessor1_Model1
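
The object returned by last_fit() also contains the fitted workflow. If you want to use this model on new data, you could, for example, extract it and predict with it (a quick sketch, output omitted):

final_workflow <- extract_workflow(childcare_fit)
predict(final_workflow, testing(childcare_split))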

This takes you through the important changes for validation sets in the tidymodels framework!

Acknowledgements

Many thanks to the people who contributed since the last releases!

For rsample: @afrogri37, @AngelFelizR, @bschneidr, @erictleung, @exsell-jc, @hfrick, @jrosell, @MasterLuke84, @MichaelChirico, @mikemahoney218, @rdavis120, @sametsoekel, @Shafi2016, @simonpcouch, @topepo, and @trevorcampbell.

For tune: @blechturm, @cphaarmeyer, @EmilHvitfeldt, @forecastingEDs, @hfrick, @kjbeath, @mikemahoney218, @rdavis120, @simonpcouch, and @topepo.