hands on: 06 decision tree

Train and Visualize a Decision Tree

Try to understand how it makes predictions.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X, y)

export_graphviz(
    tree_clf,
    out_file="tree.dot",  # the book wraps this in image_path(); a plain path works
    feature_names=iris.feature_names[2:],
    rounded=True,
    filled=True,
)

Then, we can use the dot command-line tool from the Graphviz package to convert the file to other formats:

dot -Tpng tree.dot -o iris_tree.png

Decision Trees are intuitive, and their decisions are easy to interpret; such models are called white box models. In contrast, Random Forests and neural networks are generally considered black box models.

Making Predictions

Decision Tree models require very little data preparation; in particular, no feature scaling or centering is needed.

A node’s samples attribute counts how many training instances it applies to; its value attribute tells you how many training instances of each class it applies to; its gini attribute measures its impurity: a node is pure (gini = 0) if all the training instances it applies to belong to the same class.
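These attributes can be read directly from the fitted estimator’s tree_ object. A minimal sketch, assuming scikit-learn is installed and a tree trained as in the snippet above (random_state is added here only for reproducibility):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data[:, 2:], iris.target
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)

t = tree_clf.tree_
for node in range(t.node_count):
    # value holds per-class counts (class fractions in newer scikit-learn)
    print(f"node {node}: samples={t.n_node_samples[node]}, "
          f"value={t.value[node].ravel()}, gini={t.impurity[node]:.3f}")
```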

Gini impurity $G_i = 1 - \sum_{k=1}^n p_{i,k}^2$

$p_{i,k}$ is the ratio of class k instances among the training instances in the i-th node.
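As a quick check of the formula, here is the Gini impurity of a node with illustrative class counts [0, 49, 5] (a versicolor-dominated node like the one in the iris tree):

```python
# Gini impurity G_i = 1 - sum_k p_{i,k}^2 for one node
counts = [0, 49, 5]           # illustrative per-class instance counts
m = sum(counts)               # total instances in the node
gini = 1 - sum((c / m) ** 2 for c in counts)
print(round(gini, 4))         # ≈ 0.168
```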

Estimating Class Probabilities

To estimate the probability that an instance belongs to a particular class k, a Decision Tree traverses the tree to find the leaf node for this instance, and then returns the ratio of training instances of class k in that node.
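For example, for a flower with petals 5 cm long and 1.5 cm wide (a sketch, assuming a tree trained as in the snippet above; random_state added for reproducibility):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(
    iris.data[:, 2:], iris.target)

# class ratios in the leaf node this instance falls into
print(tree_clf.predict_proba([[5, 1.5]]))
# the class with the highest estimated probability
print(tree_clf.predict([[5, 1.5]]))
```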

CART Training Algorithm

CART: Classification and Regression Tree

Scikit-Learn uses the CART algorithm, which produces only binary trees. The algorithm works by first splitting the training set into two subsets using a single feature k and a threshold $t_k$. It searches for the pair (k, $t_k$) that produces the purest subsets.

Cost function for classification

$J(k, t_k) = \frac{m_{left}}{m}G_{left} + \frac{m_{right}}{m}G_{right}$
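The cost of a candidate split is easy to evaluate by hand. A small numeric sketch with made-up numbers: m = 150 instances split into a pure left subset of 50 and a mixed right subset of 100 with Gini 0.5:

```python
# weighted-impurity cost of one hypothetical split
m, m_left, m_right = 150, 50, 100   # illustrative sizes
G_left, G_right = 0.0, 0.5          # illustrative impurities
J = (m_left / m) * G_left + (m_right / m) * G_right
print(J)  # 1/3
```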

It stops recursing once it reaches the maximum depth or once it cannot find a split that will reduce impurity. This is a greedy algorithm: it does not check whether the split will lead to the lowest possible impurity several levels down.
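One greedy CART step can be sketched as an exhaustive search over every feature and every candidate threshold. This is a simplified illustration, not Scikit-Learn’s actual implementation:

```python
import numpy as np

def gini(y):
    """Gini impurity of a label array."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Find the (feature, threshold) pair minimizing the weighted Gini cost."""
    m, n = X.shape
    best = (None, None, float("inf"))
    for k in range(n):                       # every feature
        for t in np.unique(X[:, k]):         # every observed value as threshold
            left, right = y[X[:, k] <= t], y[X[:, k] > t]
            if len(left) == 0 or len(right) == 0:
                continue
            J = (len(left) / m) * gini(left) + (len(right) / m) * gini(right)
            if J < best[2]:
                best = (k, t, J)
    return best

# toy data: feature 0 separates the two classes perfectly
X = np.array([[1.0, 9.0], [1.2, 1.0], [3.0, 9.1], [3.1, 1.1]])
y = np.array([0, 0, 1, 1])
print(best_split(X, y))  # feature 0 is chosen, with cost 0
```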

Computational Complexity

Making a prediction requires traversing the tree from root to leaf, which takes roughly $O(\log_2(m))$ node checks. The training algorithm compares all features on all samples at each node, which results in a training complexity of $O(n \times m \log_2(m))$, where $n$ is the number of features and $m$ is the number of training samples.

Gini Impurity or Entropy

By default, the Gini impurity measure is used. We can also use the entropy impurity measure by setting the criterion hyperparameter to “entropy”.


$H_i = -\sum_{k=1,\, p_{i,k}\ne 0}^{n} p_{i,k} \log_2(p_{i,k})$

In most cases, they produce similar trees. When they differ, Gini impurity tends to isolate the most frequent class in its own branch of the tree, while entropy tends to produce slightly more balanced trees.
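A quick sketch comparing the two criteria on the iris petal features (same setup as above; the trees and their training accuracy are usually close):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data[:, 2:], iris.target

# identical trees except for the impurity criterion
gini_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)
ent_clf = DecisionTreeClassifier(max_depth=2, criterion="entropy",
                                 random_state=42).fit(X, y)
print(gini_clf.score(X, y), ent_clf.score(X, y))
```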

Regularization Hyperparameter

A parametric model, such as a linear model, has a predetermined number of parameters, so its degrees of freedom are limited. A Decision Tree is a nonparametric model: the number of parameters is not determined prior to training, so it is more likely to overfit the training data.

Increasing the min_* hyperparameters or reducing the max_* hyperparameters will regularize the model. Other algorithms work by first training the Decision Tree without restrictions, then pruning unnecessary nodes. Standard statistical tests, such as the chi-squared test, can be used to estimate a node’s significance.
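As a sketch of the effect of one such hyperparameter, the following compares an unrestricted tree with one regularized by min_samples_leaf; the noisy moons dataset is an illustrative choice:

```python
from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier

X, y = make_moons(n_samples=100, noise=0.25, random_state=53)

# unrestricted tree vs. one that requires at least 4 samples per leaf
deep_clf = DecisionTreeClassifier(random_state=42).fit(X, y)
reg_clf = DecisionTreeClassifier(min_samples_leaf=4, random_state=42).fit(X, y)
print(deep_clf.get_n_leaves(), reg_clf.get_n_leaves())
```

The unrestricted tree grows enough leaves to fit the noise, while the regularized one stays smaller and generalizes better.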


Regression

from sklearn.tree import DecisionTreeRegressor

# X, y: a regression dataset, e.g. a noisy quadratic as in the book's example
tree_reg = DecisionTreeRegressor(max_depth=2)
tree_reg.fit(X, y)

The CART algorithm works mostly the same way as before, except that instead of trying to split the training set in a way that minimizes impurity, it now tries to minimize the MSE.
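The regression cost is the MSE-weighted analogue of the classification cost: $J(k, t_k) = \frac{m_{left}}{m}MSE_{left} + \frac{m_{right}}{m}MSE_{right}$. A small numeric sketch with made-up targets:

```python
import numpy as np

def mse(y):
    """Mean squared error of predicting the node's mean."""
    return np.mean((y - y.mean()) ** 2)

y = np.array([1.0, 1.2, 0.9, 3.0, 3.1, 2.9])   # toy targets
left, right = y[:3], y[3:]                     # a candidate split
J = (len(left) / len(y)) * mse(left) + (len(right) / len(y)) * mse(right)
print(J)
```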


Decision Trees prefer orthogonal decision boundaries (all splits are perpendicular to an axis), which makes them sensitive to rotations of the training set. One solution is to use PCA, which often results in a better orientation of the data.
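A sketch of that idea: rotate the data with PCA before the tree, chained in a pipeline (illustrative setup on the iris dataset):

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()

# PCA rotates the axes onto the principal components before splitting
pca_tree = make_pipeline(
    PCA(n_components=2),
    DecisionTreeClassifier(max_depth=2, random_state=42),
)
pca_tree.fit(iris.data, iris.target)
print(pca_tree.score(iris.data, iris.target))
```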

The main issue with Decision Trees is that they are very sensitive to small variations in the training data. Moreover, since Scikit-Learn’s training algorithm is stochastic, we may get very different models even on the same training data (unless random_state is set).

Author: csy99
Reprint policy: All articles in this blog are licensed under CC BY 4.0 unless otherwise stated. If reproduced, please indicate the source: csy99!