
Train Decision Trees Using Classification Learner App

This example shows how to create and compare various classification trees using Classification Learner, and export trained models to the workspace to make predictions for new data.

You can train classification trees to predict responses to data. To predict a response, follow the decisions in the tree from the root (beginning) node down to a leaf node. The leaf node contains the response.

Statistics and Machine Learning Toolbox™ trees are binary. Each step in a prediction involves checking the value of one predictor (variable). For example, here is a simple classification tree:

Decision tree with two branches

This tree predicts classifications based on two predictors, x1 and x2. To predict, start at the top node. At each decision, check the values of the predictors to decide which branch to follow. When the branches reach a leaf node, the data is classified either as type 0 or type 1.
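A tree like the one in the figure can also be built and displayed programmatically. The following is a minimal sketch using fitctree; the variables x1, x2, and the simulated labels are illustrative only and are not part of this example's data.

```matlab
% Sketch: train a binary classification tree on two predictors.
% The data here is simulated for illustration only.
rng(0)                              % for reproducibility
x1 = rand(100,1);
x2 = rand(100,1);
y  = double(x1 < 0.5 & x2 >= 0.5); % class labels 0 and 1
tbl = table(x1, x2, y);

mdl = fitctree(tbl, "y");          % requires Statistics and Machine Learning Toolbox
view(mdl, "Mode", "graph")         % display the tree as a graph
```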

  1. In MATLAB®, load the fisheriris data set and create a table of measurement predictors (or features) using variables from the data set to use for classification.

    fishertable = readtable("fisheriris.csv");
  2. On the Apps tab, in the Machine Learning and Deep Learning group, click Classification Learner.

  3. On the Classification Learner tab, in the File section, click New Session > From Workspace.

    Classification Learner tab

  4. In the New Session from Workspace dialog box, select the table fishertable from the Data Set Variable list (if necessary).

    Observe that the app has selected response and predictor variables based on their data type. Petal and sepal length and width are predictors, and species is the response that you want to classify. For this example, do not change the selections.

    New Session from Workspace dialog box

  5. To accept the default validation scheme and continue, click Start Session. The default validation option is cross-validation, to protect against overfitting.

    Classification Learner creates a scatter plot of the data.

    Scatter plot of the Fisher iris data

  6. Use the scatter plot to investigate which variables are useful for predicting the response. To visualize the distribution of species and measurements, select different variables in the X and Y lists under the Predictors section to the right of the plot. Observe which variables separate the species colors most clearly.

    Observe that the setosa species (blue points) is easy to separate from the other two species with all four predictors. The versicolor and virginica species are much closer together in all predictor measurements, and overlap especially when you plot sepal length and width. setosa is easier to predict than the other two species.
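    The same exploration can be done at the command line. Here is a hedged sketch using gscatter to color one pair of predictors by species; swap in other column pairs to compare separability.

    ```matlab
    % Sketch: plot two predictors colored by species, outside the app.
    fishertable = readtable("fisheriris.csv");
    gscatter(fishertable.PetalLength, fishertable.PetalWidth, ...
        fishertable.Species)
    xlabel("Petal length"); ylabel("Petal width")
    ```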

  7. Train fine, medium, and coarse trees simultaneously. The Models pane already contains a fine tree model. Add medium and coarse tree models to the list of draft models. On the Classification Learner tab, in the Models section, click the arrow to open the gallery. In the Decision Trees group, click Medium Tree. The app creates a draft medium tree in the Models pane. Reopen the model gallery and click Coarse Tree in the Decision Trees group. The app creates a draft coarse tree in the Models pane.

    In the Train section, click Train All and select Train All. The app trains the three tree models.

    Note

    • If you have Parallel Computing Toolbox™, then the app has the Use Parallel button toggled on by default. After you click Train All and select Train All or Train Selected, the app opens a parallel pool of workers. During this time, you cannot interact with the software. After the pool opens, you can continue to interact with the app while models train in parallel.

    • If you do not have Parallel Computing Toolbox, then the app has the Use Background Training check box in the Train All menu selected by default. After you click to train models, the app opens a background pool. After the pool opens, you can continue to interact with the app while models train in the background.

    Validation confusion matrix for a coarse tree classification model. Blue values indicate correct classifications, and red values indicate incorrect classifications.

    Note

    Validation introduces some randomness into the results. Your model validation results can vary from the results shown in this example.

  8. In the Models pane, each model has a validation accuracy score that indicates the percentage of correctly predicted responses. The app highlights the highest Accuracy (Validation) score (or scores) by outlining it in a box.

    Click a model to view the results, which are displayed in the Summary tab. On the Classification Learner tab, in the Models section, click Summary.
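    The Accuracy (Validation) score corresponds to cross-validated accuracy. A comparable number can be computed at the command line; this is a sketch of the general approach, not the app's exact internals.

    ```matlab
    % Sketch: 5-fold cross-validated accuracy for a classification tree.
    fishertable = readtable("fisheriris.csv");
    mdl = fitctree(fishertable, "Species");
    cv  = crossval(mdl, "KFold", 5);          % 5-fold cross-validation
    acc = 1 - kfoldLoss(cv);                  % fraction correctly classified
    fprintf("Validation accuracy: %.1f%%\n", 100*acc)
    ```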

  9. For each model, examine the scatter plot. On the Classification Learner tab, in the Plots section, click the arrow to open the gallery, then click Scatter in the Validation Results group. An X indicates misclassified points.

    For all three models, the blue points (setosa species) are all correctly classified, but some of the other two species are misclassified. Under Plot, switch between the Data and Model Predictions options. Observe the color of the incorrect (X) points. Alternatively, while plotting model predictions, to view only the incorrect points, clear the Correct check box.

  10. To try to improve the models, include different features during model training. See if you can improve the model by removing features with low predictive power.

    On the Classification Learner tab, in the Options section, click Feature Selection.

    In the Default Feature Selection tab, you can select different feature ranking algorithms to determine the most important features. After you select a feature ranking algorithm, the app displays a plot of the sorted feature importance scores, where larger scores (including Infs) indicate greater feature importance. The table shows the ranked features and their scores.

    In this example, the Chi2, ReliefF, ANOVA, and Kruskal Wallis feature ranking algorithms all identify the petal measurements as the most important features. Under Feature Ranking Algorithm, click Chi2.

    Default Feature Selection tab with Chi2 as the selected feature ranking algorithm

    Under Feature Selection, use the default option of selecting the highest ranked features to avoid bias in the validation metrics. Specify to keep 2 of the 4 features for model training. Click Save and Apply. The app applies the feature selection changes to new models created using the Models gallery.
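    The app's Chi2 ranking has a command-line counterpart in the fscchi2 function (available in recent Statistics and Machine Learning Toolbox releases). A sketch of ranking the iris predictors:

    ```matlab
    % Sketch: rank features by chi-square tests, as the app's Chi2 option does.
    fishertable = readtable("fisheriris.csv");
    [idx, scores] = fscchi2(fishertable, "Species");
    predictorNames = fishertable.Properties.VariableNames(1:4);
    disp(table(predictorNames(idx)', scores(idx)', ...
        'VariableNames', {'Predictor','Score'}))
    ```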

  11. Train new tree models using the reduced set of features. On the Classification Learner tab, in the Models section, click the arrow to open the gallery. In the Decision Trees group, click All Tree. In the Train section, click Train All and select Train All or Train Selected.

    The models trained using only two measurements perform comparably to the models containing all four predictors; using all the measurements yields no better predictions than using only the two. If data collection is expensive or difficult, you might prefer a model that performs satisfactorily without some predictors.

  12. Observe the last model in the Models pane. It is a Coarse Tree model, trained using only 2 of the 4 predictors. The app displays how many predictors are excluded. To check which predictors are included, click the model in the Models pane, and observe the check boxes in the expanded Feature Selection section of the model Summary tab.

    Note

    If you use a cross-validation scheme and choose to perform feature selection using the Select highest ranked features option, then for each training fold, the app performs feature selection before training a model. Different folds can choose different predictors as the highest ranked features. The table shows the list of predictors used by the full model, trained on the training and validation data.

  13. Train new tree models using another subset of measurements. On the Classification Learner tab, in the Options section, click Feature Selection. In the Default Feature Selection tab, click MRMR under Feature Ranking Algorithm. Under Feature Selection, specify to keep 3 of the 4 features for model training. Click Save and Apply.

    On the Classification Learner tab, in the Models section, click the arrow to open the gallery. In the Decision Trees group, click All Tree. In the Train section, click Train All and select Train All or Train Selected.

    The models trained using only 3 of 4 predictors do not perform as well as the other trained models.

  14. Choose a best model among those of similar scores by examining the performance in each class. For example, select the coarse tree that includes 2 of the 4 predictors. Inspect the accuracy of the predictions in each class. On the Classification Learner tab, in the Plots section, click the arrow to open the gallery, then click Confusion Matrix (Validation) in the Validation Results group. Use this plot to understand how the currently selected classifier performed in each class. View the matrix of true class and predicted class results.

    Look for areas where the classifier performed poorly by examining the red, off-diagonal cells that display high numbers. In these red cells, the true class and the predicted class do not match; the data points are misclassified.

    Confusion matrix plot

    In this figure, examine the third cell in the middle row. In this cell, the true class is versicolor, but the model misclassified the points as virginica. For this model, the cell shows 2 misclassified observations (your results can vary). To view percentages instead of numbers of observations, select the True Positive Rates option under Plot controls.

    You can use this information to help you choose the best model for your goal. If false positives in this class are very important to your classification problem, then choose the model that best predicts this class. If false positives in this class are not very important, and models with fewer predictors do better in other classes, then choose a model that trades off some overall accuracy to exclude some predictors and make future data collection easier.
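    A confusion matrix like this one can also be produced at the command line. The following is a sketch using cross-validated out-of-fold predictions with confusionchart; it mirrors the app's validation workflow but is not its exact code.

    ```matlab
    % Sketch: confusion chart from cross-validated tree predictions.
    fishertable = readtable("fisheriris.csv");
    mdl  = fitctree(fishertable, "Species");
    cv   = crossval(mdl, "KFold", 5);
    pred = kfoldPredict(cv);                   % out-of-fold predictions
    confusionchart(fishertable.Species, pred)  % rows: true class, columns: predicted
    ```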

  15. Compare the confusion matrix for each model in the Models pane. Check the Feature Selection section of the model Summary tab to see which predictors are included in each model.

    In this example, the coarse tree that includes 2 of 4 predictors performs as well as the coarse tree with all predictors. That is, both models provide the same validation accuracy and have the same confusion matrices.

  16. To further investigate features to include or exclude, use the parallel coordinates plot. On the Classification Learner tab, in the Plots section, click the arrow to open the gallery, then click Parallel Coordinates in the Validation Results group. You can see that petal length and petal width are the features that separate the classes best.

    Parallel coordinates plot

  17. To learn about model hyperparameter settings, choose a model in the Models pane and expand the Model Hyperparameters section in the model Summary tab. Compare the coarse and medium tree models, and observe the differences in the model hyperparameters. In particular, the Maximum number of splits setting is 4 for coarse trees and 20 for medium trees. This setting controls the tree depth.

    To try to improve the coarse tree model further, change the Maximum number of splits setting. First, click the model in the Models pane. On the Classification Learner tab, in the Models section, click Duplicate. In the Summary tab, change the Maximum number of splits value. Then, in the Train section of the Classification Learner tab, click Train All and select Train Selected.
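    At the command line, the Maximum number of splits setting corresponds to the MaxNumSplits name-value argument of fitctree. A sketch of comparing several depths (the specific split values are illustrative):

    ```matlab
    % Sketch: vary tree depth via MaxNumSplits, the setting the app exposes
    % as "Maximum number of splits" (4 for coarse trees, 20 for medium trees).
    fishertable = readtable("fisheriris.csv");
    for maxSplits = [4 8 20]
        mdl = fitctree(fishertable, "Species", "MaxNumSplits", maxSplits);
        cv  = crossval(mdl, "KFold", 5);
        fprintf("MaxNumSplits = %2d: accuracy = %.1f%%\n", ...
            maxSplits, 100*(1 - kfoldLoss(cv)))
    end
    ```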

  18. To export the best trained model to the workspace, on the Classification Learner tab, in the Export section, click Export Model and select Export Model. In the Export Model dialog box, click OK to accept the default variable name trainedModel.

    Look in the Command Window to see information about the results.

  19. To visualize your decision tree model, enter:

    view(trainedModel.ClassificationTree,"Mode","graph")

    Classification tree

  20. You can use the exported classifier to make predictions on new data. For example, to make predictions for the fishertable data in your workspace, enter:

    yfit = trainedModel.predictFcn(fishertable)

    The output yfit contains a class prediction for each data point.

  21. If you want to automate training the same classifier with new data, or learn how to programmatically train classifiers, you can generate code from the app. To generate code for the best trained model, on the Classification Learner tab, in the Export section, click Generate Function.

    The app generates code from your model and displays the file in the MATLAB Editor. To learn more, see Generate MATLAB Code to Train the Model with New Data.

This example uses Fisher's 1936 iris data. The iris data contains measurements of flowers: the petal length, petal width, sepal length, and sepal width for specimens from three species. Train a classifier to predict the species based on the predictor measurements.

Use the same workflow to evaluate and compare the other classifier types you can train in Classification Learner.

To try all the nonoptimizable classifier model presets available for your data set:

  1. On theClassification Learnertab, in theModelssection, click the arrow to open the gallery of classification models.

  2. In the Get Started group, click All. Then, in the Train section, click Train All and select Train All.

    Option selected for training all available classifier types

To learn about other classifier types, see Train Classification Models in Classification Learner App.
