Grow and prune a decision tree
We use the MAGIC telescope data from the UCI repository http://archive.ics.uci.edu/ml/datasets/MAGIC+Gamma+Telescope. These are simulated data for the detection of high-energy gamma particles in a ground-based atmospheric Cherenkov gamma telescope using the imaging technique. The goal is to separate gamma particles from the hadron background.
We split the data into training and test sets in the proportion 2:1 and code the class labels as a numeric vector: +1 for gamma and -1 for hadron. In this example, we use only the training set. The matrix Xtrain has about 13,000 rows (observations) and 10 columns (variables), and the vector Ytrain holds the class labels.
load MagicTelescope;
Ytrain = 2*strcmp('Gamma',Ytrain)-1;   % +1 for gamma, -1 for hadron
size(Xtrain)
ans = 12742 10
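The loaded MAT-file already contains this split. For reference, a stratified 2:1 holdout split of the raw UCI data could be produced along the following lines. This is only a sketch; the variables X and Y for the full predictor matrix and label vector are assumptions and are not part of the shipped file.
% Sketch: 2:1 holdout split of the full data (X and Y are assumed to hold
% all observations and their string labels; not part of MagicTelescope.mat).
rng(0);                                  % reproducible partition
cvp    = cvpartition(Y,'holdout',1/3);   % hold out one third, stratified by class
Xtrain = X(training(cvp),:);
Ytrain = Y(training(cvp));
Xtest  = X(test(cvp),:);
Ytest  = Y(test(cvp));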
Grow a deep tree
We use the ClassificationTree class from the Statistics Toolbox in MATLAB. By default, ClassificationTree.fit requires at least 10 observations per branch node and imposes no restriction on the leaf size.
tree = ClassificationTree.fit(Xtrain,Ytrain);
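For reference, the same call with the default split controls spelled out could look like the sketch below; 'MinParent' is the minimum number of observations per branch node and 'MinLeaf' the minimum number per leaf (the variable name treeDefault is ours).
% Sketch: the defaults written out explicitly.
treeDefault = ClassificationTree.fit(Xtrain,Ytrain,'MinParent',10,'MinLeaf',1);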
The tree is deep, as evidenced by the number of nodes.
tree.NumNodes
ans = 1651
By default, ClassificationTree.fit also computes the optimal pruning sequence for the grown tree. The PruneList property of the tree object shows the pruning level assigned to each node: if the pruning level of a node is L, this node is removed when the tree is pruned to level L+1. The pruning levels range from 0 (no pruning) to the maximal level, at which the tree is reduced to its root.
min(tree.PruneList)
ans = 0
max(tree.PruneList)
ans = 67
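To see the effect of a pruning level, the prune method can be applied directly; pruning to the maximal level should leave only the root node (a sketch, output not shown).
% Sketch: prune the tree to its maximal pruning level; the result should
% consist of a single root node.
treeRoot = prune(tree,'level',max(tree.PruneList));
treeRoot.NumNodes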
Find the optimal pruning level by cross-validation
We search for the optimal pruning level using 10-fold cross-validation. The result can be sensitive to how exactly the data are partitioned into 10 folds. We want to find the best pruning level under both the "minimal risk" rule and the "one standard error" rule, and since the same data partition should be used for both rules, we set the seed of the random number generator for reproducibility.
The cvLoss method of the tree object returns a cross-validated estimate of the classification error E, its standard deviation SE based on the binomial approximation, number of leaf nodes at each pruning level nLeaf, and the best pruning level BestMin chosen by the "minimal risk" rule. The quantities E, SE, and nLeaf are vectors with max(tree.PruneList)+1 elements. By default, cvLoss uses 10-fold cross-validation.
rng(1); [E,SE,nLeaf,BestMin] = cvLoss(tree,'subtrees','all','treesize','min');
We reset the random seed to partition the data into the same cross-validation folds and then obtain the best pruning level using the "one standard error" rule. We do not need to recompute E, SE, and nLeaf because they are fixed by the data partition.
rng(1); [~,~,~,BestSE] = cvLoss(tree,'subtrees','all','treesize','se');
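The "one standard error" rule picks the smallest tree (that is, the highest pruning level) whose cross-validated error lies within one standard error of the minimum. A sketch of that selection from the already computed E and SE follows; the exact tie-breaking used by cvLoss may differ.
% Sketch: smallest tree whose error is within one standard error of the
% minimal cross-validated error.
[minE,idxMin] = min(E);
cutoff        = minE + SE(idxMin);
levelSE       = find(E <= cutoff,1,'last') - 1;   % convert index to pruning level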
Show the pruning level, number of leaves, and cross-validated error obtained by the "minimal risk" rule. We offset the index by 1 because pruning levels are counted from zero while MATLAB arrays are indexed from one.
BestMin
BestMin = 40
nLeaf(BestMin+1)
ans = 123
E(BestMin+1)
ans = 0.1545
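The corresponding quantities for the "one standard error" rule can be displayed in the same way (output not shown).
BestSE
nLeaf(BestSE+1)
E(BestSE+1)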
Plot cross-validation results
Reproduce the figure in the book.
The first element of the E, SE, and nLeaf vectors corresponds to the zeroth pruning level, so we again offset the index by 1 to obtain values for the best pruning levels.
figure;
errorbar(nLeaf,E,SE,'ks');
xlo = 0; xhi = 200; ylo = 0.15; yhi = 0.18;
grid on;
axis([xlo xhi ylo yhi]);
% Vertical dashed line: tree size chosen by the "minimal risk" rule.
line([nLeaf(BestMin+1) nLeaf(BestMin+1)],[ylo yhi],...
    'color','r','LineStyle','--','LineWidth',3);
% Vertical solid line: tree size chosen by the "one standard error" rule.
line([nLeaf(BestSE+1) nLeaf(BestSE+1)],[ylo yhi],...
    'color','r','LineStyle','-','LineWidth',3);
% Horizontal dash-dot line: minimal cross-validated risk.
line([xlo xhi],[E(BestMin+1) E(BestMin+1)],...
    'color','r','LineStyle','-.','LineWidth',3);
xlabel('Number of leaf nodes');
ylabel('Cross-validated risk');
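Finally, the tree can be pruned to the level selected by the "one standard error" rule; this is a sketch, and the variable name treePruned is ours.
% Prune to the level chosen by the "one standard error" rule.
treePruned = prune(tree,'level',BestSE);
treePruned.NumNodes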
