Data

We will use data from the post-election survey of the 2005 British Election Study (BES). You can download the data here. For more information on the BES, please visit http://www.britishelectionstudy.com/.

Our data set includes the following variables:

Classification Trees

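Our goal is to predict voter turnout in the 2005 British general election. We use the tree() function from the tree package to build a classification tree; the formula vote_2005 ~ . - household_income uses every variable except household_income as a predictor. The summary() function lists the variables that are used as internal nodes in the tree, the number of terminal nodes, the residual mean deviance, and the (training) misclassification error rate. We see that the training error rate is 18.6% (176 of the 948 respondents are misclassified). The residual mean deviance is the total deviance of the six terminal nodes divided by 948 - 6 = 942.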

library(tree)

# Load data set
load("bes_2005.RData")

# Build a classification tree
tree_1 <- tree(vote_2005 ~ . - household_income, data = bes_2005)
summary(tree_1)
## 
## Classification tree:
## tree(formula = vote_2005 ~ . - household_income, data = bes_2005)
## Variables actually used in tree construction:
## [1] "duty"      "vote_2001" "education"
## Number of terminal nodes:  6 
## Residual mean deviance:  0.8614 = 811.4 / 942 
## Misclassification error rate: 0.1857 = 176 / 948
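The deviance figures reported by summary() can be checked by hand. Assuming the standard definition used for classification trees, the deviance of a node is -2 times the sum over classes of n_k log(n_k / n), where n_k is the number of observations of class k in the node and n is the node size. The sketch below uses a hypothetical helper, node_deviance() (not part of the tree package), and reconstructs the root-node class counts (253 and 695) from the proportions shown when tree_1 is printed (0.26688 x 948 is approximately 253).

```r
# Deviance of one classification-tree node:
# D = -2 * sum_k n_k * log(n_k / n)
node_deviance <- function(counts) {
  counts <- counts[counts > 0]   # empty classes contribute 0 to the sum
  n <- sum(counts)
  -2 * sum(counts * log(counts / n))
}

# Root node of tree_1: 948 respondents, about 26.7% in class 0
root_dev <- node_deviance(c(253, 695))

# Node 2 (duty: 3,4,5): 186 respondents, two thirds in class 0
node_2_dev <- node_deviance(c(124, 62))

# Residual mean deviance: total terminal-node deviance divided by
# (number of observations - number of terminal nodes)
resid_mean_dev <- 811.4 / (948 - 6)

root_dev
node_2_dev
resid_mean_dev
```

The results, about 1100, 236.8, and 0.8614, agree with the root and node 2 deviances in the tree printout and with the residual mean deviance above.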

We can use the plot() function to graphically display the tree structure, and the text() function to display the node labels. The argument pretty = 0 instructs R to include the category names for qualitative predictors, rather than simply displaying a letter for each category.

# Graphically display the tree
plot(tree_1)
text(tree_1, pretty = 0)

If we type the name of the tree object, R prints output corresponding to each branch of the tree. R displays the split criterion, the number of observations in that branch, the deviance, the overall prediction for the branch (0 or 1), and the fractions of observations in that branch that take on the values 0 and 1, respectively. Branches that lead to terminal nodes are indicated by an asterisk.

# Examine the tree object
tree_1
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 948 1100.00 1 ( 0.26688 0.73312 )  
##    2) duty: 3,4,5 186  236.80 0 ( 0.66667 0.33333 )  
##      4) vote_2001: 0 91   87.65 0 ( 0.81319 0.18681 ) *
##      5) vote_2001: 1 95  131.40 0 ( 0.52632 0.47368 ) *
##    3) duty: 1,2 762  693.10 1 ( 0.16929 0.83071 )  
##      6) vote_2001: 0 83  115.10 0 ( 0.50602 0.49398 )  
##       12) education: 5,6,7 21   17.22 0 ( 0.85714 0.14286 ) *
##       13) education: 1,2,3,4 62   82.76 1 ( 0.38710 0.61290 ) *
##      7) vote_2001: 1 679  519.90 1 ( 0.12813 0.87187 )  
##       14) duty: 1 321  144.30 1 ( 0.05919 0.94081 ) *
##       15) duty: 2 358  348.10 1 ( 0.18994 0.81006 ) *
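Each node's prediction (yval) is the majority class among the training observations in that node, and the class counts can be recovered from the n and yprob columns. The sketch below uses a hypothetical helper, node_summary(), and applies it to node 6, which shows how narrow some majority votes are: 42 respondents in class 0 against 41 in class 1.

```r
# Recover class counts and the majority-vote prediction of a node
# from the printed n and yprob columns of a tree object.
node_summary <- function(n, yprob, classes = c("0", "1")) {
  counts <- round(n * yprob)   # yprob is printed to 5 decimals, so round
  names(counts) <- classes
  list(counts = counts, yval = classes[which.max(yprob)])
}

# Node 6 (duty: 1,2 and vote_2001: 0) from the printed tree_1 output
node_6 <- node_summary(83, c(0.50602, 0.49398))
node_6$counts
node_6$yval
```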

In order to evaluate the performance of the classification tree, we must estimate the test error rather than simply computing the training error. We split the observations into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data.

set.seed(1234)

# Create training and test sets
train <- sample(1:nrow(bes_2005), size = as.integer(nrow(bes_2005) / 2))
bes_2005_test <- bes_2005[-train, ]
vote_2005_test <- bes_2005$vote_2005[-train]

# Grow tree on training data
tree_2 <- tree(vote_2005 ~ ., data = bes_2005, subset = train)
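The negative index -train selects every row that was not drawn into the training sample, so the training and test sets are disjoint and together cover the whole data set. A self-contained sketch of the same pattern on a hypothetical toy data frame:

```r
set.seed(1234)

# Hypothetical toy data standing in for bes_2005
toy <- data.frame(x = 1:10, y = rep(c(0, 1), 5))

# Draw half of the row indices for training
train_idx <- sample(seq_len(nrow(toy)), size = nrow(toy) %/% 2)

toy_train <- toy[train_idx, ]
toy_test  <- toy[-train_idx, ]   # negative indexing drops the training rows
```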

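We can use the predict() function to predict outcomes for the test data. The argument type = "class" instructs R to return the predicted class for each observation rather than the class probabilities.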

# Predict outcomes
tree_2_pred <- predict(tree_2, newdata = bes_2005_test, type = "class")

# Confusion matrix
table(prediction = tree_2_pred, truth = vote_2005_test)
##           truth
## prediction   0   1
##          0  58  19
##          1  76 321

We correctly classify approximately 80% of the observations in the test data set (58 + 321 = 379 of 474).

# Percent correctly classified
mean(tree_2_pred == vote_2005_test)
## [1] 0.7995781
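The same accuracy can be computed directly from the confusion matrix, since the correctly classified observations lie on its diagonal. The sketch below enters the counts from the table above by hand; it also computes the share of each true class that is classified correctly, which in this table reveals much weaker performance on class 0 (58 of 134) than on class 1 (321 of 340).

```r
# Confusion matrix for tree_2 on the test set, copied from the output above
# (rows = predicted class, columns = true class; matrix() fills by column)
conf_mat <- matrix(c(58, 76, 19, 321), nrow = 2,
                   dimnames = list(prediction = c("0", "1"),
                                   truth = c("0", "1")))

# Overall accuracy: correct predictions are on the diagonal
accuracy <- sum(diag(conf_mat)) / sum(conf_mat)

# Share of each true class that is classified correctly
class_accuracy <- diag(conf_mat) / colSums(conf_mat)

accuracy
class_accuracy
```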

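Next, we use cost-complexity pruning to see whether we can simplify the tree and thus reduce its variance at the cost of at most a small increase in bias. The cv.tree() function uses K-fold cross-validation (K = 10 by default) to determine the optimal size of the tree. Setting FUN = prune.misclass makes the classification error rate, rather than the deviance, guide the cross-validation and pruning process.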
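In the standard cost-complexity formulation, pruning with prune.misclass() looks, for each value of a tuning parameter alpha >= 0, for the subtree T of the full tree that minimizes the number of misclassified training observations plus a penalty on the number of terminal nodes |T|. Here R_m denotes the region of terminal node m and the hat-y term is that node's majority class; prune.tree() uses the deviance in place of the misclassification count. The parameter alpha is what cv.tree() reports as k: alpha = 0 keeps the full tree, and larger values favour smaller trees.

```latex
C_\alpha(T) = \sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} I\big(y_i \neq \hat{y}_{R_m}\big) + \alpha \, |T|
```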

set.seed(1234)
cv_tree_2 <- cv.tree(tree_2, FUN = prune.misclass)

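# Plot the CV misclassification count against tree size (left) and against the cost-complexity parameter k (right)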
par(mfrow = c(1, 2))
plot(cv_tree_2$size, cv_tree_2$dev, type = "b")
plot(cv_tree_2$k, cv_tree_2$dev, type = "b")
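Instead of reading the optimal size off the plot, it can be selected programmatically. The sketch below uses a hypothetical helper, best_tree_size(), that takes the size and dev components of a cv.tree() result and returns the smallest tree attaining the minimum cross-validated error; preferring the smallest size breaks ties, which are common when dev counts misclassifications. The toy values are made up for illustration and are not the BES results.

```r
# Smallest tree size whose cross-validated error equals the minimum.
# `size` and `dev` mirror the components of the same name returned by
# tree::cv.tree().
best_tree_size <- function(size, dev) {
  min(size[dev == min(dev)])
}

# Hypothetical CV results, for illustration only
toy_size <- c(9, 6, 4, 2, 1)
toy_dev  <- c(101, 99, 95, 95, 120)

best_tree_size(toy_size, toy_dev)
```

With the objects from this tutorial, the call would be best_tree_size(cv_tree_2$size, cv_tree_2$dev). Exact equality is safe here because misclassification counts are integers; for deviance-based CV on regression trees, exact ties are unlikely.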

Based on the cross-validation results, the tree can be pruned to four terminal nodes. We use the prune.misclass() function with best = 4 to obtain this pruned tree.

prune_tree_2 <- prune.misclass(tree_2, best = 4)
par(mfrow = c(1, 1))
plot(prune_tree_2)
text(prune_tree_2, pretty = 0)

How well does this pruned tree perform on the test data set? Once again, we apply the predict() function.

# Predict outcomes
prune_tree_2_pred <- predict(prune_tree_2, newdata = bes_2005_test, type = "class")

# Confusion matrix
table(prediction = prune_tree_2_pred, truth = vote_2005_test)
##           truth
## prediction   0   1
##          0  51  11
##          1  83 329

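Now, 80.2% of the test observations are correctly classified (51 + 329 = 380 of 474, one more than the unpruned tree), so pruning slightly improved the classification accuracy while also producing a simpler, more interpretable tree.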

# Percent correctly classified
mean(prune_tree_2_pred == vote_2005_test)
## [1] 0.8016878

Regression Trees

Here our goal is to predict the household income of respondents. To do so, we build a regression tree.

# Build a regression tree
tree_3 <- tree(household_income ~ ., data = bes_2005, subset = train)
summary(tree_3)
## 
## Regression tree:
## tree(formula = household_income ~ ., data = bes_2005, subset = train)
## Variables actually used in tree construction:
## [1] "education" "vote_2001" "duty"     
## Number of terminal nodes:  6 
## Residual mean deviance:  7.935 = 3713 / 468 
## Distribution of residuals:
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
## -6.7160 -2.0500 -0.7159  0.0000  1.6890  9.6890
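For a regression tree, each terminal node predicts the mean response of its training observations, and the node deviance is the residual sum of squares (RSS) around that mean. Under that definition, the residual mean deviance of 7.935 = 3713 / 468 is the total RSS of the six terminal nodes divided by the 474 training observations minus 6. The toy sketch below, with hypothetical helpers node_prediction() and node_rss() and made-up incomes in two terminal nodes, shows the calculation.

```r
# Prediction and deviance (RSS) of a regression-tree node
node_prediction <- function(y) mean(y)
node_rss <- function(y) sum((y - mean(y))^2)

# Hypothetical incomes falling into two terminal nodes
left  <- c(3, 4, 5, 4)
right <- c(8, 10, 9, 11, 12)

total_rss <- node_rss(left) + node_rss(right)

# Residual mean deviance as reported by summary(): RSS / (n - |T|)
n_obs <- length(left) + length(right)
resid_mean_dev_reg <- total_rss / (n_obs - 2)

node_prediction(left)
node_prediction(right)
resid_mean_dev_reg
```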

We again use the plot() and text() functions to display the tree.

# Graphically display the tree
plot(tree_3)
text(tree_3, pretty = 0)

We estimate the test error of the regression tree by computing the mean squared error (MSE) of its predictions on the test set.

# Predict outcomes
tree_3_pred <- predict(tree_3, newdata = bes_2005_test)

# MSE
household_income_test <- bes_2005$household_income[-train]
mean((tree_3_pred - household_income_test)^2)
## [1] 10.61289
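Because the MSE is in squared units, its square root (the RMSE) expresses the typical prediction error on the same scale as household_income: roughly 3.26 here. A minimal sketch, with a hypothetical test_mse() helper checked on toy values:

```r
# Test mean squared error: average squared gap between predictions
# and observed values
test_mse <- function(pred, truth) mean((pred - truth)^2)

# Toy check with hypothetical values: (1 + 0 + 4) / 3
toy_mse <- test_mse(c(2, 4, 6), c(3, 4, 8))

# Square root of the reported test MSE of tree_3
rmse_tree_3 <- sqrt(10.61289)

toy_mse
rmse_tree_3
```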

Now we use the cv.tree() function to see whether pruning the tree will improve performance.

cv_tree_3 <- cv.tree(tree_3)

# Plot the cross-validated deviance against the number of terminal nodes
plot(cv_tree_3$size, cv_tree_3$dev, type = "b")

The tree with 3 terminal nodes results in the lowest cross-validated deviance. Because this is a regression tree, we apply the prune.tree() function (rather than prune.misclass()) to prune the tree to obtain the 3-node tree.

prune_tree_3 <- prune.tree(tree_3, best = 3)
plot(prune_tree_3)
text(prune_tree_3, pretty = 0)