Cross validation

Prof. Maria Tackett

Oct 24, 2022

Announcements

  • Lab 05

    • due TODAY, 11:59pm (Thu labs)

    • due Tuesday, 11:59pm (Fri labs)

  • Office hours update:

    • Monday, 1 - 2pm: in-person only (Old Chem 118B)
  • Click here for an explanation of the sum of squares in R ANOVA output.

  • See Week 09 activities.

Spring 2023 Statistics classes

  • STA 211: Mathematics of Regression

    • Pre-req: STA 210 + Math 216/218/221
  • STA 240: Probability for Statistical Inference, Modeling, and Data Analysis

    • Pre-req: Calc 2
  • STA 313: Advanced Data Visualization

    • Pre-req: STA 199 or STA 210
  • STA 323: Statistical Computing

    • Pre-req: STA 210 and STA 230 / 240
  • STA 360: Bayesian Inference and Modern Statistical Methods

    • Pre-req: STA 210, STA 230/240, CS 101, Calc 2, Math 216/218/221
    • Co-req: STA 211

Topics

  • Cross validation for model evaluation
  • Cross validation for model comparison

Computational setup

# load packages
library(tidyverse)
library(tidymodels)
library(knitr)
library(schrute)

Data & goal

  • Data: The data come from the schrute package and have been transformed using instructions from Lab 04
  • Goal: Predict imdb_rating from other variables in the dataset

office_episodes <- read_csv(here::here("slides/data/office_episodes.csv"))
office_episodes
# A tibble: 186 × 14
   season episode episode_n…¹ imdb_…² total…³ air_date   lines…⁴ lines…⁵ lines…⁶
    <dbl>   <dbl> <chr>         <dbl>   <dbl> <date>       <dbl>   <dbl>   <dbl>
 1      1       1 Pilot           7.6    3706 2005-03-24  0.157   0.179    0.354
 2      1       2 Diversity …     8.3    3566 2005-03-29  0.123   0.0591   0.369
 3      1       3 Health Care     7.9    2983 2005-04-05  0.172   0.131    0.230
 4      1       4 The Allian…     8.1    2886 2005-04-12  0.202   0.0905   0.280
 5      1       5 Basketball      8.4    3179 2005-04-19  0.0913  0.0609   0.452
 6      1       6 Hot Girl        7.8    2852 2005-04-26  0.159   0.130    0.306
 7      2       1 The Dundies     8.7    3213 2005-09-20  0.125   0.160    0.375
 8      2       2 Sexual Har…     8.2    2736 2005-09-27  0.0565  0.0954   0.353
 9      2       3 Office Oly…     8.4    2742 2005-10-04  0.196   0.117    0.295
10      2       4 The Fire        8.4    2713 2005-10-11  0.160   0.0690   0.216
# … with 176 more rows, 5 more variables: lines_dwight <dbl>, halloween <chr>,
#   valentine <chr>, christmas <chr>, michael <chr>, and abbreviated variable
#   names ¹​episode_name, ²​imdb_rating, ³​total_votes, ⁴​lines_jim, ⁵​lines_pam,
#   ⁶​lines_michael

Modeling prep

Split data into training and testing

set.seed(123)
office_split <- initial_split(office_episodes)
office_train <- training(office_split)
office_test <- testing(office_split)
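
By default, initial_split() reserves about 3/4 of the rows for training; the proportion can be set explicitly with the prop argument. A minimal sketch making the default explicit:

initial_split(office_episodes, prop = 3/4)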

Specify model

office_spec <- linear_reg() |>
  set_engine("lm")

office_spec
Linear Regression Model Specification (regression)

Computational engine: lm 

Model 1

From Lab 04

Create a recipe that uses the newly generated variables

  • Denote episode_name as an ID variable and don’t use air_date or season as predictors
  • Create dummy variables for all nominal predictors
  • Remove all zero-variance predictors

Create recipe

office_rec1 <- recipe(imdb_rating ~ ., data = office_train) |>
  update_role(episode_name, new_role = "id") |>
  step_rm(air_date, season) |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors())

office_rec1
Recipe

Inputs:

      role #variables
        id          1
   outcome          1
 predictor         12

Operations:

Variables removed air_date, season
Dummy variables from all_nominal_predictors()
Zero variance filter on all_predictors()

Preview recipe

prep(office_rec1) |>
  bake(office_train) |>
  glimpse()
Rows: 139
Columns: 12
$ episode       <dbl> 20, 16, 8, 7, 23, 3, 16, 21, 18, 14, 27, 28, 12, 1, 23, …
$ episode_name  <fct> "Welcome Party", "Moving On", "Performance Review", "The…
$ total_votes   <dbl> 1489, 1572, 2416, 1406, 2783, 1802, 2283, 2041, 1445, 14…
$ lines_jim     <dbl> 0.12703583, 0.05588822, 0.09523810, 0.07482993, 0.078291…
$ lines_pam     <dbl> 0.10423453, 0.10978044, 0.10989011, 0.15306122, 0.081850…
$ lines_michael <dbl> 0.0000000, 0.0000000, 0.3772894, 0.0000000, 0.3736655, 0…
$ lines_dwight  <dbl> 0.07166124, 0.08782435, 0.15384615, 0.18027211, 0.135231…
$ imdb_rating   <dbl> 7.2, 8.2, 8.2, 7.7, 9.1, 8.2, 8.3, 8.9, 8.0, 7.8, 8.7, 8…
$ halloween_yes <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ valentine_yes <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ christmas_yes <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,…
$ michael_yes   <dbl> 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,…

Create workflow

office_wflow1 <- workflow() |>
  add_model(office_spec) |>
  add_recipe(office_rec1)

office_wflow1
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_rm()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Computational engine: lm 

Fit model to training data

Not so fast!

Cross validation

Spending our data

  • We have already discussed the idea of data spending, where the test set was recommended for obtaining an unbiased estimate of performance.
  • However, we usually need to understand the effectiveness of the model before using the test set.
  • Typically we can’t decide on which final model to take to the test set without making model assessments.
  • Remedy: Resampling to make model assessments on training data in a way that can generalize to new data.

Resampling for model assessment

Resampling is only conducted on the training set. The test set is not involved. For each iteration of resampling, the data are partitioned into two subsamples:

  • The model is fit with the analysis set. Model fit statistics such as \(R^2_{Adj}\), AIC, and BIC are calculated based on this fit.
  • The model is evaluated with the assessment set.

Resampling for model assessment


Source: Kuhn and Silge. Tidy modeling with R.

Analysis and assessment sets

  • Analysis set is analogous to training set.
  • Assessment set is analogous to test set.
  • The terms analysis and assessment avoid confusion with the initial split of the data.
  • These data sets are mutually exclusive.

Cross validation

More specifically, v-fold cross validation is a commonly used resampling technique:

  • Randomly split your training data into v partitions
  • Use v-1 partitions for analysis (model fit + model fit statistics), and the remaining 1 partition for assessment
  • Repeat v times, updating which partition is used for assessment each time

Let’s give an example where v = 3

Cross validation, step 1

Randomly split your training data into 3 partitions:


Split data

set.seed(345)
folds <- vfold_cv(office_train, v = 3)
folds
#  3-fold cross-validation 
# A tibble: 3 × 2
  splits          id   
  <list>          <chr>
1 <split [92/47]> Fold1
2 <split [93/46]> Fold2
3 <split [93/46]> Fold3
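
Each element of splits holds one resample; the rsample helpers analysis() and assessment() return the two subsamples for a given fold. A quick sketch:

fold1 <- folds$splits[[1]]
dim(analysis(fold1))    # analysis set: about 2/3 of office_train
dim(assessment(fold1))  # assessment set: the remaining 1/3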

Cross validation, steps 2 and 3

  • Use v-1 partitions for analysis, and the remaining 1 partition for assessment
  • Repeat v times, updating which partition is used for assessment each time

Fit resamples

# Function to get Adj R-sq, AIC, BIC
calc_model_stats <- function(x) {
  glance(extract_fit_parsnip(x)) |>
    select(adj.r.squared, AIC, BIC)
}

set.seed(456)

# Fit model and calculate statistics for each fold
office_fit_rs1 <- office_wflow1 |>
  fit_resamples(resamples = folds, 
                control = control_resamples(extract = calc_model_stats))

office_fit_rs1
# Resampling results
# 3-fold cross-validation 
# A tibble: 3 × 5
  splits          id    .metrics         .notes           .extracts       
  <list>          <chr> <list>           <list>           <list>          
1 <split [92/47]> Fold1 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
2 <split [93/46]> Fold2 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
3 <split [93/46]> Fold3 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>

Cross validation, now what?

  • We’ve fit a bunch of models
  • Now it’s time to collect metrics (e.g., \(R^2\), AIC, RMSE) for each model and use them to evaluate model fit and how it varies across folds

Collect \(R^2\) and RMSE from CV

# Produces summary across all CV folds
collect_metrics(office_fit_rs1)
# A tibble: 2 × 6
  .metric .estimator  mean     n std_err .config             
  <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
1 rmse    standard   0.353     3  0.0117 Preprocessor1_Model1
2 rsq     standard   0.539     3  0.0378 Preprocessor1_Model1


Note: These are calculated using the assessment data

Deeper look into \(R^2\) and RMSE

cv_metrics1 <- collect_metrics(office_fit_rs1, summarize = FALSE) 

cv_metrics1
# A tibble: 6 × 5
  id    .metric .estimator .estimate .config             
  <chr> <chr>   <chr>          <dbl> <chr>               
1 Fold1 rmse    standard       0.355 Preprocessor1_Model1
2 Fold1 rsq     standard       0.525 Preprocessor1_Model1
3 Fold2 rmse    standard       0.373 Preprocessor1_Model1
4 Fold2 rsq     standard       0.481 Preprocessor1_Model1
5 Fold3 rmse    standard       0.332 Preprocessor1_Model1
6 Fold3 rsq     standard       0.610 Preprocessor1_Model1

Better tabulation of \(R^2\) and RMSE from CV

cv_metrics1 |>
  mutate(.estimate = round(.estimate, 3)) |>
  pivot_wider(id_cols = id, names_from = .metric, values_from = .estimate) |>
  kable(col.names = c("Fold", "RMSE", "R-squared"))
Fold    RMSE   R-squared
Fold1   0.355  0.525
Fold2   0.373  0.481
Fold3   0.332  0.610

How does RMSE compare to y?

Cross validation RMSE stats:

cv_metrics1 |>
  filter(.metric == "rmse") |>
  summarise(
    min = min(.estimate),
    max = max(.estimate),
    mean = mean(.estimate),
    sd = sd(.estimate)
  )
# A tibble: 1 × 4
    min   max  mean     sd
  <dbl> <dbl> <dbl>  <dbl>
1 0.332 0.373 0.353 0.0202

IMDB rating stats (all episodes):

office_episodes |>
  summarise(
    min = min(imdb_rating),
    max = max(imdb_rating),
    mean = mean(imdb_rating),
    sd = sd(imdb_rating)
  )
# A tibble: 1 × 4
    min   max  mean    sd
  <dbl> <dbl> <dbl> <dbl>
1   6.7   9.7  8.25 0.535
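
To put these side by side, one option is to express the mean CV RMSE as a fraction of the spread of the response (a quick sketch using the objects above; rel_rmse is just an illustrative name):

cv_metrics1 |>
  filter(.metric == "rmse") |>
  summarise(rel_rmse = mean(.estimate) / sd(office_episodes$imdb_rating))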

Collect \(R^2_{Adj}\), AIC, BIC from CV

# Pull the Adj R-sq, AIC, BIC extracted from each fold's analysis-set fit
map_df(office_fit_rs1$.extracts, ~ .x[[1]][[1]]) |>
  bind_cols(Fold = office_fit_rs1$id)
# A tibble: 3 × 4
  adj.r.squared   AIC   BIC Fold 
          <dbl> <dbl> <dbl> <chr>
1         0.585  70.3 101.  Fold1
2         0.615  63.0  93.4 Fold2
3         0.553  77.6 108.  Fold3


Note: These are based on the model fit from the analysis data

Cross validation jargon

  • Referred to as v-fold or k-fold cross validation
  • Also commonly abbreviated as CV

Cross validation in practice

  • To illustrate how CV works, we used v = 3:

    • Analysis sets are 2/3 of the training set
    • Each assessment set is a distinct 1/3
    • The final resampling estimate of performance averages each of the 3 replicates
  • This was useful for illustrative purposes, but v = 3 is a poor choice in practice

  • Values of v are most often 5 or 10; we generally prefer 10-fold cross-validation as a default (see the sketch below)
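
For example, a 10-fold version of the earlier resampling split (a minimal sketch; folds_10 is just an illustrative name):

set.seed(345)
folds_10 <- vfold_cv(office_train, v = 10)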

Application exercise

Recap

  • Cross validation for model evaluation
  • Cross validation for model comparison