rm(list=ls())
We first use classification trees to analyze the Carseats data set. The main goal is to predict the sales of car seats and to find the important features that influence sales.
library (ISLR)
library(plotly)
A data frame with 400 observations on the following 11 variables.
Sales - Unit sales (in thousands) at each location
CompPrice - Price charged by competitor at each location
Income - Community income level (in thousands of dollars)
Advertising - Local advertising budget for company at each location (in thousands of dollars)
Population - Population size in region (in thousands)
Price - Price company charges for car seats at each site
ShelveLoc - A factor with levels Bad, Good and Medium indicating the quality of the shelving location for the car seats at each site
Age - Average age of the local population
Education - Education level at each location
Urban - A factor with levels No and Yes to indicate whether the store is in an urban or rural location
US - A factor with levels No and Yes to indicate whether the store is in the US or not
head(Carseats)
sum(is.na(Carseats))
summary(Carseats)
In these data, Sales is a continuous variable, and so we begin by recoding it as a binary variable. We use the ifelse() function to create a variable, called High, which takes on a value of Yes if the Sales variable exceeds 8, and takes on a value of No otherwise.
attach(Carseats)
# recode Sales as a binary factor: High = "Yes" if Sales > 8, "No" otherwise
Carseats$High <- factor(ifelse(Sales <= 8, "No", "Yes"))
Finally, we use the data.frame()
function to merge High with the rest of
the Carseats
data.
head(Carseats)
plot(Carseats$Price,Carseats$Advertising,col=c("red","blue")[Carseats$High])
plot(Advertising, ShelveLoc, col = c("red", "blue")[Carseats$High])
plot_ly(Carseats, x = ~Price, y = ~Advertising, z = ~ShelveLoc, color = ~High, colors = c('#BF382A', '#0C4B8E')) %>%
add_markers() %>%
layout(scene = list(xaxis = list(title = ''),
yaxis = list(title = ''),
zaxis = list(title = '')))
l <- glm(High~. -Sales, data = Carseats, family=binomial)
summary(l)
exp(0.295)
1-exp(-0.383)
This fitted model says that, holding the other predictors fixed, we expect a $34\%$ increase in the odds of high sales for a one-unit increase in Advertising, since $\exp(0.295) = 1.34$. Similarly, holding the rest at fixed values, the odds of high sales for stores in urban areas (UrbanYes = 1) relative to the odds for non-urban stores are $\exp(-0.383) = 0.68$. In terms of percent change, the odds of high sales for urban stores are about $32\%$ lower than the odds for non-urban stores.
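As a quick check, the odds ratios for every coefficient can be computed directly from the fitted object rather than typed in by hand. The snippet below is only a small sketch using the model l fitted above; confint.default() gives Wald-type intervals, so the intervals are approximate.
exp(coef(l))                # odds ratios for all predictors
exp(confint.default(l))     # approximate 95% Wald confidence intervals on the odds-ratio scale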
set.seed(33)
test.index <- sample(1:nrow(Carseats), size = 150, replace = F)
test <- Carseats[test.index,]
train <- Carseats[-test.index,]
l <- glm(High~. -Sales, data = train, family=binomial)
l.pred <- predict(l, train, type = "response")
l.pred1 <- as.numeric(l.pred > 0.5)
l.table<-table(l.pred1 ,train$High )
l.table
sum(l.pred1)
(139+82)/250
summary.pred <- function(x)
{
  # x is a confusion matrix from table(predicted, actual):
  # rows = predicted class (negative first, positive second), columns = actual class
  FPR <- x[1,2]/sum(x[1,])   # share of predicted-negative cases that are actually positive
  TPR <- x[2,2]/sum(x[2,])   # share of predicted-positive cases that are actually positive
  PPR <- x[2,2]/sum(x[,2])   # share of actual positives that are predicted positive
  NPR <- x[1,1]/sum(x[,1])   # share of actual negatives that are predicted negative
  output <- list("FPR" = round(FPR,2), "TPR" = round(TPR,2),
                 "PPR" = round(PPR,2), "NPR" = round(NPR,2))
  return(output)
}
l.pred.summary<-summary.pred(l.table)
l.pred.summary
l.pred <- predict(l, test, type = "response")
l.pred1 <- as.numeric(l.pred > 0.5)
l.table.test<-table(l.pred1 ,test$High )
l.table.test
(80+56)/150
l.test.pred.summary <- summary.pred(l.table.test)
l.test.pred.summary
We now use the tree() function to fit a classification tree in order to predict High using all variables but Sales. The syntax of the tree() function is quite similar to that of the lm() function.
library(tree)
tree.carseats = tree(High ~ . - Sales, data = Carseats)
summary(tree.carseats)
We see that the training error rate is 9%. For classification trees, the deviance reported in the output of summary() is given by
$$-2 \sum_m \sum_k n_{mk}\log \hat p_{mk}$$
where $n_{mk}$ is the number of observations in the $m$th terminal node that belong to the $k$th class. A small deviance indicates a tree that provides a good fit to the (training) data. The residual mean deviance reported is simply the deviance divided by $n - |T_0|$, which in this case is $400 - 27 = 373$.
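To see where these numbers come from, the deviance can be recomputed by hand from the fitted class probabilities. This is only a sketch, assuming the tree.carseats object fitted on the full Carseats data above; predict() with type="vector" returns the $\hat p_{mk}$ for each observation.
probs <- predict(tree.carseats, Carseats, type = "vector")        # per-observation class probabilities p_mk
obs.class <- as.integer(Carseats$High)                            # column index of each observation's class
dev <- -2 * sum(log(probs[cbind(1:nrow(Carseats), obs.class)]))   # -2 * sum_m sum_k n_mk log p_mk
n.leaves <- sum(tree.carseats$frame$var == "<leaf>")              # number of terminal nodes |T_0|
dev / (nrow(Carseats) - n.leaves)                                 # residual mean deviance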
One of the most attractive properties of trees is that they can be
graphically displayed. We use the plot()
function to display the tree structure,
and the text()
function to display the node labels. The argument
pretty=0
instructs R
to include the category names for any qualitative predictors,
rather than simply displaying a letter for each category.
plot(tree.carseats)
text(tree.carseats, pretty = 0)
The most important indicator of Sales
appears to be shelving location,
since the first branch differentiates Good
locations from Bad
and Medium
locations.
If we just type the name of the tree object, R
prints output corresponding
to each branch of the tree. R
displays the split criterion (e.g. $Price<92.5$), the
number of observations in that branch, the deviance, the overall prediction
for the branch (Yes
or No
), and the fraction of observations in that branch
that take on values of Yes
and No
. Branches that lead to terminal nodes are
indicated using asterisks.
tree.carseats
In order to properly evaluate the performance of a classification tree on
these data, we must estimate the test error rather than simply computing
the training error. We split the observations into a training set and a test
set, build the tree using the training set, and evaluate its performance on
the test data. The predict()
function can be used for this purpose. In the
case of a classification tree, the argument type="class"
instructs R
to return
the actual class prediction. This approach leads to correct predictions for
around $80\%$ of the locations in the test data set.
tree.carseats = tree(High ~ . - Sales, data = train)
tree.pred = predict(tree.carseats, test, type = "class")
t.table <- table(tree.pred, test$High)
t.table
sum(diag(t.table))/sum(t.table)   # overall accuracy on the 150 test observations
t.test.pred.summary<-summary.pred(t.table)
t.test.pred.summary
Next, we consider whether pruning the tree might lead to improved results. The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity; cost-complexity pruning is used in order to select a sequence of trees for consideration. We use the argument FUN=prune.misclass in order to indicate that we want the classification error rate to guide the cross-validation and pruning process, rather than the default for the cv.tree() function, which is deviance. The cv.tree() function reports the number of terminal nodes of each tree considered (size) as well as the corresponding error rate and the value of the cost-complexity parameter used (k, which corresponds to $\alpha$; see the formula in the lecture notes).
set.seed(3)
cv.carseats = cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats )
cv.carseats
Note that, despite the name, dev
corresponds to the cross-validation error
rate in this instance. The tree with $9$ terminal nodes results in the lowest
cross-validation error rate, with $50$ cross-validation errors. We plot the error rate as a function of both size and $k$.
par(mfrow =c(1,2))
plot(cv.carseats$size ,cv.carseats$dev ,type="b")
plot(cv.carseats$k ,cv.carseats$dev ,type="b")
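Rather than reading the best size off the plot, it can also be extracted programmatically. This is a small optional sketch using only the cv.carseats object computed above.
best.size <- cv.carseats$size[which.min(cv.carseats$dev)]   # tree size with the fewest CV misclassifications
best.size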
We now apply the prune.misclass() function in order to prune the tree; below we request the best subtree with five terminal nodes (best = 5).
prune.carseats = prune.misclass(tree.carseats, best = 5)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
How well does this pruned tree perform on the test data set? Once again,
we apply the predict()
function.
tree.pred = predict(prune.carseats, test, type = "class")
t.table <- table(tree.pred, test$High)
t.table
(64+49)/150
t.test.pred.summary<-summary.pred(t.table)
t.test.pred.summary
Now $75\%$ of the test observations are correctly classified: the pruning process has produced a much smaller and more interpretable tree, with test classification accuracy comparable to that of the full tree.
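Because the name t.table is reused above for both the unpruned and the pruned tree, a quick way to compare the two on the same footing is to recompute both test accuracies in one step. This is only an optional sketch reusing objects defined earlier (tree.carseats fitted on train, prune.carseats, and test).
unpruned.acc <- mean(predict(tree.carseats, test, type = "class") == test$High)
pruned.acc <- mean(predict(prune.carseats, test, type = "class") == test$High)
c(unpruned = unpruned.acc, pruned = pruned.acc)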
dim(train)
Finally, we apply bagging and random forests to these data, using the randomForest package. Recall that bagging is simply a special case of a random forest in which mtry equals the number of predictors; excluding Sales, there are 10 predictors here.
library(randomForest)
set.seed(1)
bag.model = randomForest(High ~ . - Sales, data = train,
                         mtry = 10, importance = TRUE)
bag.model
bag.pred = predict(bag.model, newdata = test)
b.table <- table(bag.pred, test$High)
b.table
sum(diag(b.table))/sum(b.table)   # overall test accuracy for bagging
b.test.pred.summary<-summary.pred(b.table)
b.test.pred.summary
For a random forest we consider only a subset of the predictors at each split; here we use mtry = 4.
set.seed(1)
rf.model = randomForest(High ~ . - Sales, data = train,
                        mtry = 4, importance = TRUE)
rf.model
pred.rf = predict(rf.model, newdata = test)
rf.table <- table(pred.rf, test$High)
rf.table
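Since the stated goal is also to find the important features that influence sales, we can inspect the variable-importance measures computed because importance = TRUE was set. The closing snippet below is only a sketch that follows the same pattern as the earlier models; importance() and varImpPlot() are standard randomForest functions.
sum(diag(rf.table))/sum(rf.table)                # overall test accuracy for the random forest
rf.test.pred.summary <- summary.pred(rf.table)   # same summary measures as for the earlier models
rf.test.pred.summary
importance(rf.model)    # mean decrease in accuracy and in Gini index for each predictor
varImpPlot(rf.model)    # plot the two importance measures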