Cross-validation

Split data into training and test

library(ISLR2)
set.seed(1)
train <- sample(392, 196) # randomly select 196 of the 392 observations for training
train
##   [1] 324 167 129 299 270 187 307  85 277 362 330 263 329  79 213  37 105 217
##  [19] 366 165 290 383  89 289 340 326 382  42 111  20  44 343  70 121  40 172
##  [37]  25 248 198  39 298 280 160  14 130  45  22 206 230 193 104 367 255 341
##  [55] 342 103 331  13 296 375 176 279 110  84  29 141 252 221 108 304  33 347
##  [73] 149 287 102 145 118 323 107  64 224 337  51 325 372 138 390 389 282 143
##  [91] 285 170  48 204 295  24 181 214 225 163  43   1 328  78 284 116 233  61
## [109]  86 374  49 242 246 247 239 219 135 364 363 310  53 348  65 376 124  77
## [127] 218  98 194  19  31 174 237  75  16 358   9  50  92 122 152 386 207 244
## [145] 229 350 355 391 223 373 309 140 126 349 344 319 258  15 271 388 195 201
## [163] 318  17 212 127 133  41 384 392 159 117  72  36 315 294 157 378 313 306
## [181] 272 106 185  88 281 228 238 368  80  30  93 234 220 240 369 164

Data

data(Auto)
head(Auto)
##   mpg cylinders displacement horsepower weight acceleration year origin
## 1  18         8          307        130   3504         12.0   70      1
## 2  15         8          350        165   3693         11.5   70      1
## 3  18         8          318        150   3436         11.0   70      1
## 4  16         8          304        150   3433         12.0   70      1
## 5  17         8          302        140   3449         10.5   70      1
## 6  15         8          429        198   4341         10.0   70      1
##                        name
## 1 chevrolet chevelle malibu
## 2         buick skylark 320
## 3        plymouth satellite
## 4             amc rebel sst
## 5               ford torino
## 6          ford galaxie 500
attach(Auto) # make the columns of Auto accessible by name

Fit a regression model

lm.fit <- lm(mpg ~ horsepower, data = Auto, subset = train)

Compute error on test set

## Make predictions for all observations in Auto
y_hat <- predict(lm.fit, Auto)
head(y_hat)
##         1         2         3         4         5         6 
## 19.227919 13.289865 15.834746 15.834746 17.531332  7.691129
## MSE on the test set
## Method 1
mean((mpg - y_hat)[-train]^2)
## [1] 23.26601
## Method 2
mean((mpg - predict(lm.fit, Auto))[-train]^2)
## [1] 23.26601
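
Both MSE expressions above rely on attach(Auto) to resolve mpg. As a minimal sketch (the helper name val_mse is my own), the same computation can be made explicit and reusable:

## Sketch: validation-set MSE for a fitted model; test may be a
## negative index such as -train
val_mse <- function(fit, data, y, test) {
  mean((data[[y]] - predict(fit, data))[test]^2)
}
val_mse(lm.fit, Auto, "mpg", -train) # reproduces 23.26601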

Polynomial regression

lm.fit2 <- lm(mpg ~ poly(horsepower, 2), data = Auto, subset = train)
mean((mpg - predict(lm.fit2, Auto))[-train]^2)
## [1] 18.71646
lm.fit3 <- lm(mpg ~ poly(horsepower, 3), data = Auto, subset = train)
mean((mpg - predict(lm.fit3, Auto))[-train]^2)
## [1] 18.79401

Perform the same calculations on a different training and test set

set.seed(2)
train <- sample(392, 196)
lm.fit <- lm(mpg ~ horsepower, data = Auto, subset = train)
mean((mpg - predict(lm.fit, Auto))[-train]^2)
## [1] 25.72651
lm.fit2 <- lm(mpg ~ poly(horsepower, 2), data = Auto, subset = train)
mean((mpg - predict(lm.fit2, Auto))[-train]^2)
## [1] 20.43036
lm.fit3 <- lm(mpg ~ poly(horsepower, 3), data = Auto, subset = train)
mean((mpg - predict(lm.fit3, Auto))[-train]^2)
## [1] 20.38533
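
The estimates shifted (25.73 vs. 23.27 for the linear fit) solely because the split changed. A short sketch that repeats the split over several seeds makes this variability of the validation-set estimate explicit:

## Sketch: validation-set MSE of the linear fit across ten random splits
sapply(1:10, function(s) {
  set.seed(s)
  tr <- sample(392, 196)
  fit <- lm(mpg ~ horsepower, data = Auto, subset = tr)
  mean((Auto$mpg - predict(fit, Auto))[-tr]^2)
})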

Leave-One-Out Cross-Validation

The LOOCV estimate can be computed for any generalized linear model using glm() together with cv.glm() from the boot package.

## Method 1
lm.fit <- lm(mpg ~ horsepower, data = Auto)
coef(lm.fit)
## (Intercept)  horsepower 
##  39.9358610  -0.1578447
## Method 2: glm() without a family argument defaults to gaussian and fits the same model
glm.fit <- glm(mpg ~ horsepower, data = Auto)
coef(glm.fit)
## (Intercept)  horsepower 
##  39.9358610  -0.1578447
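
Passing the family explicitly gives the identical fit, since gaussian with the identity link is glm()'s default. A quick check (the object name glm.fit.gauss is my own):

glm.fit.gauss <- glm(mpg ~ horsepower, data = Auto, family = gaussian)
all.equal(coef(glm.fit.gauss), coef(lm.fit)) # should be TRUE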

LOOCV

library(boot)
glm.fit <- glm(mpg ~ horsepower, data = Auto)
cv.err <- cv.glm(Auto, glm.fit)
cv.err$delta
## [1] 24.23151 24.23114
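
The two components of delta are the raw cross-validation estimate and a bias-corrected version; for LOOCV they are nearly identical. For a least-squares fit, the same quantity can also be computed without refitting n times, via the hat-value shortcut CV = mean(((y - yhat) / (1 - h))^2). A minimal sketch:

## Sketch: LOOCV for least squares via leverages; no refitting needed
h <- hatvalues(glm.fit) # leverage of each observation
mean((residuals(glm.fit) / (1 - h))^2) # should match cv.err$delta[1]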

Repeat the process for polynomials

cv.error <- rep(0, 10)
for (i in 1:10) {
  glm.fit <- glm(mpg ~ poly(horsepower, i), data = Auto)
  cv.error[i] <- cv.glm(Auto, glm.fit)$delta[1]
}
cv.error
##  [1] 24.23151 19.24821 19.33498 19.42443 19.03321 18.97864 18.83305 18.96115
##  [9] 19.06863 19.49093
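
The estimated error drops sharply from the linear to the quadratic fit and then levels off. To read off the minimizing degree (a sketch):

which.min(cv.error) # degree 7 here, though degrees 2 through 9 are close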

k-Fold Cross-Validation

set.seed(17)
cv.error.10 <- rep(0, 10)
for (i in 1:10) {
  glm.fit <- glm(mpg ~ poly(horsepower, i), data = Auto)
  cv.error.10[i] <- cv.glm(Auto, glm.fit, K = 10)$delta[1]
}
cv.error.10
##  [1] 24.27207 19.26909 19.34805 19.29496 19.03198 18.89781 19.12061 19.14666
##  [9] 18.87013 20.95520
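
The 10-fold estimates closely track the LOOCV ones at a fraction of the computational cost. A quick visual comparison (a sketch using base graphics):

## Sketch: LOOCV vs. 10-fold CV estimates by polynomial degree
plot(1:10, cv.error, type = "b", xlab = "Polynomial degree",
     ylab = "CV error estimate")
lines(1:10, cv.error.10, type = "b", col = "red")
legend("topright", legend = c("LOOCV", "10-fold CV"),
       col = c("black", "red"), lty = 1)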

Cross-validation using the caret package

library(caret)
## Loading required package: ggplot2
## 
## Attaching package: 'ggplot2'
## The following object is masked from 'Auto':
## 
##     mpg
## Loading required package: lattice
## 
## Attaching package: 'lattice'
## The following object is masked from 'package:boot':
## 
##     melanoma
data(mtcars)
set.seed(47)

model <- train(mpg ~ hp, mtcars, method = "lm",
               trControl = trainControl(method = "cv", number = 10, # number of folds
                                        verboseIter = TRUE))
## + Fold01: intercept=TRUE 
## - Fold01: intercept=TRUE 
## + Fold02: intercept=TRUE 
## - Fold02: intercept=TRUE 
## + Fold03: intercept=TRUE 
## - Fold03: intercept=TRUE 
## + Fold04: intercept=TRUE 
## - Fold04: intercept=TRUE 
## + Fold05: intercept=TRUE 
## - Fold05: intercept=TRUE 
## + Fold06: intercept=TRUE 
## - Fold06: intercept=TRUE 
## + Fold07: intercept=TRUE 
## - Fold07: intercept=TRUE 
## + Fold08: intercept=TRUE 
## - Fold08: intercept=TRUE 
## + Fold09: intercept=TRUE 
## - Fold09: intercept=TRUE 
## + Fold10: intercept=TRUE 
## - Fold10: intercept=TRUE 
## Aggregating results
## Fitting final model on full training set
model
## Linear Regression 
## 
## 32 samples
##  1 predictor
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 29, 28, 29, 28, 28, 30, ... 
## Resampling results:
## 
##   RMSE     Rsquared   MAE     
##   3.92794  0.8368878  3.415551
## 
## Tuning parameter 'intercept' was held constant at a value of TRUE
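
caret also keeps the per-fold metrics on the fitted object; a minimal sketch of how to inspect them:

## Sketch: fold-level and aggregated resampling results
model$resample # RMSE, Rsquared, MAE for each of the 10 folds
model$results  # the aggregated values printed above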

Random forest with the caret package

Hyperparameters

mtry: the number of predictors randomly sampled as split candidates at each node.

ntree: the number of trees grown in the forest.

library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
library(caret)


# Random search
control <- trainControl(method = "repeatedcv", number = 10, repeats = 3, search = "random")
set.seed(2)

rf_random <- train(mpg~., data=mtcars, method="rf",  tuneLength=15, trControl=control)
print(rf_random)
## Random Forest 
## 
## 32 samples
## 10 predictors
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 3 times) 
## Summary of sample sizes: 28, 29, 28, 30, 28, 29, ... 
## Resampling results across tuning parameters:
## 
##   mtry  RMSE      Rsquared   MAE     
##   1     2.486211  0.9358560  2.170779
##   2     2.284862  0.9368179  2.007097
##   3     2.187498  0.9473989  1.914519
##   6     2.121223  0.9450167  1.856398
##   7     2.148674  0.9348568  1.883043
##   8     2.167080  0.9194429  1.896898
##   9     2.189869  0.9308789  1.913791
## 
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was mtry = 6.
plot(rf_random)

rf_random <- train(mpg~., data=mtcars, method="rf",  tuneLength=5, trControl=control)
print(rf_random)
## Random Forest 
## 
## 32 samples
## 10 predictors
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 3 times) 
## Summary of sample sizes: 28, 29, 30, 30, 28, 29, ... 
## Resampling results across tuning parameters:
## 
##   mtry  RMSE      Rsquared   MAE     
##   1     2.654707  0.8980172  2.269682
##   2     2.438300  0.9145830  2.112039
##   5     2.304727  0.9220358  1.979418
##   6     2.292552  0.9241559  1.962358
##   7     2.328859  0.9212295  1.984720
## 
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was mtry = 6.
plot(rf_random)

Note: ntree cannot be tuned via tuneGrid in caret. For method = "rf" the only tuneGrid parameter is mtry; splitrule and min.node.size are tuning parameters of the "ranger" method instead. ntree can, however, be passed directly to train(), which forwards it to randomForest().
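
A minimal sketch of a grid search under that constraint, with mtry supplied via tuneGrid and ntree forwarded to randomForest() (the grid values are illustrative):

## Sketch: tune mtry on a grid; ntree is passed through to randomForest()
grid <- expand.grid(mtry = c(2, 4, 6, 8))
set.seed(2)
rf_grid <- train(mpg ~ ., data = mtcars, method = "rf",
                 tuneGrid = grid, ntree = 1000,
                 trControl = trainControl(method = "repeatedcv",
                                          number = 10, repeats = 3))
print(rf_grid)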