Cross validation

STA 210 - Summer 2022

Author

Yunran Chen

Welcome

Announcements

  • Cross validation for model evaluation
  • Cross validation for model comparison

Computational setup

# load packages
library(tidyverse)
library(tidymodels)
library(knitr)
library(schrute)

Data & goal

  • Data: The data come from the schrute package and have been transformed using instructions from Lab 4 (released on Wednesday)
  • Goal: Predict imdb_rating from the other variables in the dataset
office_episodes <- read_csv(here::here("slides", "data/office_episodes.csv"))
office_episodes
# A tibble: 186 × 14
   season episode episode_name      imdb_rating total_votes air_date   lines_jim
    <dbl>   <dbl> <chr>                   <dbl>       <dbl> <date>         <dbl>
 1      1       1 Pilot                     7.6        3706 2005-03-24    0.157 
 2      1       2 Diversity Day             8.3        3566 2005-03-29    0.123 
 3      1       3 Health Care               7.9        2983 2005-04-05    0.172 
 4      1       4 The Alliance              8.1        2886 2005-04-12    0.202 
 5      1       5 Basketball                8.4        3179 2005-04-19    0.0913
 6      1       6 Hot Girl                  7.8        2852 2005-04-26    0.159 
 7      2       1 The Dundies               8.7        3213 2005-09-20    0.125 
 8      2       2 Sexual Harassment         8.2        2736 2005-09-27    0.0565
 9      2       3 Office Olympics           8.4        2742 2005-10-04    0.196 
10      2       4 The Fire                  8.4        2713 2005-10-11    0.160 
# … with 176 more rows, and 7 more variables: lines_pam <dbl>,
#   lines_michael <dbl>, lines_dwight <dbl>, halloween <chr>, valentine <chr>,
#   christmas <chr>, michael <chr>

Review: workflow for building a model

  • Spending data: split data into training and test sets
  • Specify the question: association between y and x(s)
  • Feature engineering
  • Model fitting, condition checks, and evaluation
  • Model comparison
  • Inference (hypothesis tests + confidence intervals)
  • Prediction
  • Conclusions

Modeling prep

Split data into training and testing

set.seed(123)
office_split <- initial_split(office_episodes)
office_train <- training(office_split)
office_test <- testing(office_split)
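To see roughly what initial_split() does with its default prop = 3/4, the base-R sketch below draws a random 75/25 split of row indices. It illustrates the idea only and is not rsample's internal code; the data frame `toy` is hypothetical and simply has the same 186 rows as office_episodes, and rounding the training size down is an assumption.

```r
set.seed(123)

# Hypothetical stand-in with the same number of rows as office_episodes
toy <- data.frame(id = 1:186, y = rnorm(186))

n_train   <- floor(nrow(toy) * 3 / 4)          # 3/4 of 186, rounded down: 139 rows
train_idx <- sample(nrow(toy), size = n_train)  # random row indices, no replacement

toy_train <- toy[train_idx, ]   # analogous to training(office_split)
toy_test  <- toy[-train_idx, ]  # analogous to testing(office_split)

c(train = nrow(toy_train), test = nrow(toy_test))
```

The 139 training rows line up with the 92 + 47 rows that appear in the cross-validation folds later in these slides.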

Specify model

office_spec <- linear_reg() %>%
  set_engine("lm")

office_spec
Linear Regression Model Specification (regression)

Computational engine: lm 

Model 1

One possible recipe

  • Creates a recipe that uses the new variables we generated
  • Designates episode_name as an ID variable and removes air_date so it is not used as a predictor
  • Creates dummy variables for all nominal predictors
  • Removes all zero-variance predictors

Create recipe

office_rec1 <- recipe(imdb_rating ~ ., data = office_train) %>%
  update_role(episode_name, new_role = "id") %>%
  step_rm(air_date) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors())

office_rec1
Recipe

Inputs:

      role #variables
        id          1
   outcome          1
 predictor         12

Operations:

Variables removed air_date
Dummy variables from all_nominal_predictors()
Zero variance filter on all_predictors()
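Conceptually, step_dummy() and step_zv() do something like the base-R sketch below: each nominal column becomes a 0/1 indicator, and any column that takes a single value in the training data is dropped. The toy rows and their halloween/christmas values are made up for illustration, and this is not how recipes implements the steps internally.

```r
# Hypothetical toy training data with two yes/no nominal predictors.
# In these rows christmas is always "no", so its dummy column will be constant.
toy_train <- data.frame(
  imdb_rating = c(8.1, 7.9, 8.6, 8.3),
  halloween   = c("yes", "no", "no", "no"),
  christmas   = c("no", "no", "no", "no")
)

# Dummy coding (the idea behind step_dummy()): "no" is the reference level,
# so each yes/no variable becomes one 0/1 indicator for "yes"
dummies <- data.frame(
  halloween_yes = as.numeric(toy_train$halloween == "yes"),
  christmas_yes = as.numeric(toy_train$christmas == "yes")
)

# Zero-variance filter (the idea behind step_zv()): keep only columns
# that take more than one distinct value in the training data
keep <- vapply(dummies, function(col) length(unique(col)) > 1, logical(1))
dummies_zv <- dummies[, keep, drop = FALSE]

names(dummies_zv)  # "halloween_yes"
```

A constant column carries no information for lm() and would only produce an NA coefficient, which is why the filter runs after the dummy step.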

Create workflow

office_wflow1 <- workflow() %>%
  add_model(office_spec) %>%
  add_recipe(office_rec1)

office_wflow1
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_rm()
• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Computational engine: lm 

Build model

  • Fit model to training data
  • Make predictions on testing data
  • Evaluate model

Data splitting is random! Results may vary from one split to another.

Repeat the split several times and average the results to make the conclusion more robust.
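The "repeat and average" idea can be sketched in base R: redo a random 75/25 split with different seeds, compute the test RMSE each time, and average. The built-in mtcars data and the formula mpg ~ wt + hp are stand-ins chosen only for illustration; this is not a tidymodels function.

```r
# Test RMSE from one random 75/25 split (illustrative: mtcars, mpg ~ wt + hp)
split_rmse <- function(data, seed) {
  set.seed(seed)
  n_train   <- floor(nrow(data) * 3 / 4)
  train_idx <- sample(nrow(data), size = n_train)
  fit  <- lm(mpg ~ wt + hp, data = data[train_idx, ])  # fit on training rows
  test <- data[-train_idx, ]
  pred <- predict(fit, newdata = test)                  # predict held-out rows
  sqrt(mean((test$mpg - pred)^2))                       # test RMSE
}

# Twenty different random splits
rmses <- vapply(1:20, function(s) split_rmse(mtcars, s), numeric(1))

summary(rmses)  # the spread shows how much a single split can vary
mean(rmses)     # averaged, more stable estimate
```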

Cross validation

Spending our data

  • Idea of data spending: the test set is reserved for evaluating final model performance.
  • Training data (model fitting) + test data (model prediction)
  • Before using the test set, we need a way to check that the model is effective
  • and to decide which final model to take to the test set
  • Idea: treat the training set as your dataset and split it again. Repeat and take the average.

Resampling for model assessment

Resampling is only conducted on the training set. The test set is not involved. For each iteration of resampling, the data are partitioned into two subsamples:

  • The model is fit with the analysis set (training).
  • The model is evaluated with the assessment set (test).

Figure: resampling for model assessment. Source: Kuhn and Silge, Tidy Modeling with R.

Analysis and assessment sets

  • The analysis set is analogous to the training set.
  • The assessment set is analogous to the test set.
  • The terms analysis and assessment avoid confusion with the initial split of the data.
  • Within each resample, the analysis and assessment sets are mutually exclusive.

Cross validation

More specifically, v-fold cross validation is a commonly used resampling technique:

  • Randomly split your training data into v partitions
  • Use 1 partition for assessment, and the remaining v-1 partitions for analysis
  • Repeat v times, updating which partition is used for assessment each time
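The three steps above can be sketched by hand in base R: assign every row to one of v folds at random, then loop over the folds, fitting on the other v - 1 folds and scoring on the held-out one. The mtcars data, the formula mpg ~ wt + hp, and the helper name cv_rmse() are illustrative assumptions, not part of tidymodels.

```r
# Hypothetical helper: per-fold RMSE from v-fold cross validation
cv_rmse <- function(data, v, seed = 1) {
  set.seed(seed)
  # Step 1: randomly assign each row to exactly one of v folds
  fold <- sample(rep(1:v, length.out = nrow(data)))
  # Steps 2 and 3: each fold takes one turn as the assessment set
  vapply(1:v, function(k) {
    analysis   <- data[fold != k, ]  # v - 1 folds used for fitting
    assessment <- data[fold == k, ]  # 1 held-out fold used for evaluation
    fit  <- lm(mpg ~ wt + hp, data = analysis)
    pred <- predict(fit, newdata = assessment)
    sqrt(mean((assessment$mpg - pred)^2))
  }, numeric(1))
}

fold_rmse <- cv_rmse(mtcars, v = 3)
fold_rmse        # one RMSE per fold
mean(fold_rmse)  # the cross-validated estimate
```

Because each row gets a single fold label, every row is held out exactly once, and within any one iteration the analysis and assessment rows never overlap.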

Let’s give an example where v = 3

Cross validation, step 1

Randomly split your training data into 3 partitions:


Split data

set.seed(345)
folds <- vfold_cv(office_train, v = 3)
folds
#  3-fold cross-validation 
# A tibble: 3 × 2
  splits          id   
  <list>          <chr>
1 <split [92/47]> Fold1
2 <split [93/46]> Fold2
3 <split [93/46]> Fold3
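The 92/47 and 93/46 sizes follow from dealing the 139 training rows into 3 folds as evenly as possible. The sketch below reproduces that arithmetic in base R, assuming rows are spread evenly across folds (which matches the output above); fold_sizes() is a hypothetical helper, not an rsample function.

```r
# Hypothetical helper: fold sizes when n rows are dealt into v folds as evenly as possible
fold_sizes <- function(n, v) as.vector(table(rep(1:v, length.out = n)))

assess_3 <- fold_sizes(139, v = 3)
assess_3        # assessment-set sizes: 47 46 46
139 - assess_3  # analysis-set sizes:   92 93 93
```

With v = 10 the same arithmetic gives nine assessment sets of 14 rows and one of 13.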

Cross validation, steps 2 and 3

  • Use 1 partition for assessment, and the remaining v-1 partitions for analysis
  • Repeat v times, updating which partition is used for assessment each time

Fit resamples

set.seed(456)

office_fit_rs1 <- office_wflow1 %>%
  fit_resamples(folds)

office_fit_rs1
# Resampling results
# 3-fold cross-validation 
# A tibble: 3 × 4
  splits          id    .metrics         .notes          
  <list>          <chr> <list>           <list>          
1 <split [92/47]> Fold1 <tibble [2 × 4]> <tibble [0 × 3]>
2 <split [93/46]> Fold2 <tibble [2 × 4]> <tibble [0 × 3]>
3 <split [93/46]> Fold3 <tibble [2 × 4]> <tibble [0 × 3]>
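Conceptually, fit_resamples() loops over the splits: for each fold it fits the workflow on analysis(split) and computes metrics on assessment(split). The sketch below does the same by hand with rsample's analysis() and assessment() accessors and a plain lm(), using mtcars and mpg ~ wt + hp as stand-ins. It skips the recipe and the R-squared metric that fit_resamples() also handles, so treat it as an illustration rather than an equivalent.

```r
library(rsample)

set.seed(345)
folds_demo <- vfold_cv(mtcars, v = 3)

# For each split: fit on the analysis set, compute RMSE on the assessment set
fold_rmse <- vapply(folds_demo$splits, function(s) {
  fit  <- lm(mpg ~ wt + hp, data = analysis(s))
  held <- assessment(s)
  sqrt(mean((held$mpg - predict(fit, newdata = held))^2))
}, numeric(1))

fold_rmse  # one RMSE per fold, like the rmse rows stored in .metrics
```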

Cross validation, now what?

  • We’ve fit a bunch of models
  • Now it’s time to collect metrics (e.g., R-squared, RMSE) for each model and use them to evaluate model fit and how it varies across folds

Collect CV metrics

collect_metrics(office_fit_rs1)
# A tibble: 2 × 6
  .metric .estimator  mean     n std_err .config             
  <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
1 rmse    standard   0.351     3  0.0111 Preprocessor1_Model1
2 rsq     standard   0.546     3  0.0378 Preprocessor1_Model1

Deeper look into CV metrics

cv_metrics1 <- collect_metrics(office_fit_rs1, summarize = FALSE) 

cv_metrics1
# A tibble: 6 × 5
  id    .metric .estimator .estimate .config             
  <chr> <chr>   <chr>          <dbl> <chr>               
1 Fold1 rmse    standard       0.356 Preprocessor1_Model1
2 Fold1 rsq     standard       0.520 Preprocessor1_Model1
3 Fold2 rmse    standard       0.367 Preprocessor1_Model1
4 Fold2 rsq     standard       0.498 Preprocessor1_Model1
5 Fold3 rmse    standard       0.330 Preprocessor1_Model1
6 Fold3 rsq     standard       0.621 Preprocessor1_Model1
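The summary from collect_metrics() is consistent with taking the mean of these per-fold estimates and a standard error of sd / sqrt(n), with n = 3 folds. The sketch below recomputes both from the rounded values printed above, so the results agree only up to rounding; summarise_folds() is a hypothetical helper.

```r
# Per-fold estimates from the output above (rounded to 3 decimals)
rmse <- c(0.356, 0.367, 0.330)
rsq  <- c(0.520, 0.498, 0.621)

# Hypothetical helper: mean and standard error across folds
summarise_folds <- function(x) c(mean = mean(x), std_err = sd(x) / sqrt(length(x)))

summarise_folds(rmse)  # close to mean 0.351, std_err 0.0111
summarise_folds(rsq)   # close to mean 0.546, std_err 0.0378
```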

Better tabulation of CV metrics

cv_metrics1 %>%
  mutate(.estimate = round(.estimate, 3)) %>%
  pivot_wider(id_cols = id, names_from = .metric, values_from = .estimate) %>%
  kable(col.names = c("Fold", "RMSE", "R-squared"))
Fold    RMSE    R-squared
Fold1   0.356   0.520
Fold2   0.367   0.498
Fold3   0.330   0.621

How does RMSE compare to the spread of y?

Cross validation RMSE stats:

cv_metrics1 %>%
  filter(.metric == "rmse") %>%
  summarise(
    min = min(.estimate),
    max = max(.estimate),
    mean = mean(.estimate),
    sd = sd(.estimate)
  )
# A tibble: 1 × 4
    min   max  mean     sd
  <dbl> <dbl> <dbl>  <dbl>
1 0.330 0.367 0.351 0.0192

IMDB rating stats (full dataset, office_episodes):

office_episodes %>%
  summarise(
    min = min(imdb_rating),
    max = max(imdb_rating),
    mean = mean(imdb_rating),
    sd = sd(imdb_rating)
  )
# A tibble: 1 × 4
    min   max  mean    sd
  <dbl> <dbl> <dbl> <dbl>
1   6.7   9.7  8.25 0.535
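One rough way to read this comparison: a baseline that always predicts the mean rating has an RMSE essentially equal to the standard deviation of imdb_rating (exactly sd(y) * sqrt((n - 1) / n)). A CV RMSE of 0.351 against an SD of 0.535 therefore means the model's typical error is about two thirds of that baseline's. The ten ratings below are the first rows of office_episodes printed earlier, and the ratio is a heuristic, not a formal test.

```r
# First ten imdb_rating values printed earlier in these slides
y <- c(7.6, 8.3, 7.9, 8.1, 8.4, 7.8, 8.7, 8.2, 8.4, 8.4)
n <- length(y)

# A "model" that always predicts mean(y)
baseline_rmse <- sqrt(mean((y - mean(y))^2))

# Its RMSE equals sd(y) rescaled by sqrt((n - 1) / n)
all.equal(baseline_rmse, sd(y) * sqrt((n - 1) / n))  # TRUE

# Ratio of the CV RMSE to the SD of imdb_rating reported above
0.351 / 0.535
```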

Cross validation jargon

  • Referred to as v-fold or k-fold cross validation
  • Also commonly abbreviated as CV

Cross validation, for reals

  • To illustrate how CV works, we used v = 3:

    • Analysis sets are 2/3 of the training set
    • Each assessment set is a distinct 1/3
    • The final resampling estimate of performance averages each of the 3 replicates
  • This was useful for illustrative purposes, but v = 3 is a poor choice in practice
  • Values of v are most often 5 or 10; we generally prefer 10-fold cross-validation as a default

Application exercise

Recap

  • Cross validation for model evaluation
  • Cross validation for model comparison