[ml] Decision Tree模型

Aug 1, 2016   #machine learning  #tree  #cart 

1.Intro

决策树常用的算法有ID3,C4.5CART,学习过程分为三部分:特征选择决策树生成决策树剪枝

决策树的生成只考虑局部最优(贪心算法),而决策树的剪枝则考虑全局最优。简而言之,这是一个递归地选择最优特征,并根据该特征对训练数据进行分割,使得对每个子数据集有一个最好的分类过程

They are easy to interpret, handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions.

2.准则函数

主要用来衡量节点数据集合的有序性,有熵(Entropy)基尼指数(Gini)方差(Variance),其中前两种针对分类问题,而方差针对回归问题。寻找最好的分割点是通过量化分割后类的纯度来确定的,计算方式总结如下表:

Impurity Task Description formula
分类 假设有K类,样本点属于第i类的概率为\(p_i\),度量数据的不确定性 \(E(D) =\sum_{i=1}^{K} -p_ilogp_i\)
基尼 分类 同上 \(Gini(D) = \sum_{i=1}^{K} p_i*(1-p_i)\)
方差 回归 度量数据的离散程度 \(Variance(D) = \frac{1}{N} \sum_{i=1}^N (y_i-\mu)^2\),\(\mu = \frac{1}{N} \sum_{i=1}^N y_i\)

信息增益: 得知某特征X的信息而使目标Y的信息不确定减少的程度。信息增益大的特征有更强的分类能力。

具体方法:对训练集D,计算每个特征的信息增益,比较大小,选择信息增益大的特征列来分割集合。假设分割点将数据集分成左右两部分,则信息增益为 \( IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})\)

ID3s算法选择信息增益作分类,为了校正信息增益存在“偏向选择取值较多的特征”的缺点,C4.5采用信息增益比\(IG_R(D,s) = \frac{IG(D,s)}{E_s(D)}\), \(E_s(D)= \sum_{i=1}^K -\frac{|D_i|}{|D|}*log_2(\frac{|D_i|}{|D|})\), 其中K是特征s的取值个数。

3 生成算法

具体方法:从根节点开始,计算所有可能的特征的准则值(应用准则函数),选择最优的特征作为节点的特征,由该特征的不同取值建立子节点;再对子节点递归调用以上方法,构建决策树。

CART

CART(classification and regression tree),又称分类与回归树。

CART的生成是一个递归地构建二叉树的过程,对回归树应用平方误差最小化准则,对分类树用Gini指数最小化准则,进行特征选择,生成二叉树。预测时,先根据样本特征判断该样本点位于哪块区域,然后用该区域内的训练样本点集的众数 (分类) 或均值 (回归) 作为该样本点的预测值。

算法描述:CART树生成算法
输入: 训练数据集D,停止条件
输出: 决策树f(x)
(1)在平方误差最小或Gini最小的准则下选择最优切分变量j与切分点s,切分变量从可用特征集中选择。 回归树中,单元\(R_m\)下的最优值是\(R_m\)上所有输入实例对应的输出\(y_i\)的均值,即切分点。分类中,切分点是该变量所有可能的取值。
(2)用选定的对(j,s)划分区域并决定相应的输出值
(3)继续对子区域调用(1),(2),直至满足停止条件
(4)将输入空间划分为M个区域\(R_1\),\(R_2\),…,\(R_M\).生成如下决策树:

$$f(x) = \sum_{m=1}^{M} c_m*I(x \in R_m)$$

算法的停止条件是节点中的样本个数小于预定阈值,或样本集的基尼指数小于预订阈值(样本基本属于同一类),或没有更多特征。

4 剪枝过程(pruning)

将已生成的树进行简化的过程,通过极小化决策树整体的损失函数(loss function)来实现。 分为前向剪枝和后向剪枝。前向剪枝是在构造决策树的同时对树进行剪枝;后向剪枝是在决策树构建完成后,对树从根节点向上递归剪枝。

5.MLlib实现

连续值特征

对于连续值类型的特征,单台机器上的做法是选取所有出现过的值作为切分点候选集。加速计算的方法是对该集合按值大小排序,用排序好的数组做候选集。

在分布式中,对于规模大的数据集,排序是比较耗时的。于是先对整体数据按比例采样,然后采用分位数(等频分箱)作为分割点的近似候选集(优化点)。

采样策略为选取样本集中的max(maxBins*maxBins, 10000)条记录。等频分箱策略有三种(Sort,MinMax,ApproxHist),目前只支持第一种。规定分箱总数不能超过训练样本的实例数(maxBins的默认值为32)。

类别值特征

对于离散值类型的特征,有m个值,那么最多产生\(2^{m-1}-1\)种划分可能,\(2*(2^{m-1}-1)\)个分箱。对于二分类或回归问题,可以将划分候选缩减到m-1种可能。

举个例子,在二分类问题中,一个特征有A,B,C三个值。特征值为A的样本集中label是1的比例为0.2,B和C分别为0.6,0.4。则产生了一个类别值的有序序列ACB,共有两种划分可能:A|C,B,A,C|B。

多分类问题中,离散特征会产生\(2^{m-1}-1\)种划分可能,当划分数大于maxBins,将m个类别值根据impurity排序,产生共m-1种划分可能。

bin和split

split是划分点,bin是划分区间。每个特征都分别有一组bin和split。

其中bin中预先计算样本的统计信息,节省计算开销(优化点)。分类问题,bin中保存训练样本个数 (count) 以及各 label 数目 (count for label);回归问题,bin中保存训练样本个数 (count)、 label 之和 (sum) 以及 label 的平方和 (squared sum)。

特征选择时,对于某个特征,计算不同切分变量下信息增益的大小,确定出该特征下的最佳切分点。选择全部候选特征集中最优的切分点作为当前决策。特征选择是树的level级别的并行,对于同一层次的节点,查找可以并行,查找操作的时间复杂度为O(L),L为树的层数(优化点)。

终止条件

为了防止过拟合,采用前向剪枝。当任一以下情况发生,节点就终止划分,形成叶子节点: (1)树高度达到maxDepth
(2)minInfoGain,当前节点的所有属性分割带来的信息增益都比这个值要小
(3)minInstancesPerNode,需要保证节点分割出的左右子节点的最少的样本数量达到这个值

源码分析

下面就一起揭开decision tree的真正面纱吧!

输入 - RDD[LabelPoint]: LabelPoint类有两个属性label(标签值), features(特征向量)。将训练数据集转换为LabelPoint形式,作为决策树训练的输入之一。

输入 - Strategy: 这是关于模型训练过程中需要指定的参数集。

buildMetadata(): DecisionTreeMetadata是一个决策树元数据信息类,其中大部分属性和Strategy类重合。该函数的功能: 对于类别值特征,区分有序/无序;统计类别值特征的划分点数目和分箱数目。是训练决策树模型的一个准备工作。

def buildMetadata(input:RDD[LabeledPoint], strategy:Strategy,
numTrees: Int, featureSubsetStrategy: String):DecisionTreeMetadata =
{...}

对于每个类别值特征fi,多分类情况下:首先计算最大可能分箱maxPossibleBins确定情况下,能够允许的最多类别值m,如果fi的类别值超过m,则当作有序处理;否则认为该特征是无序的,并计算分箱数。

二分类/回归情况下:当作有序处理,即numBins(fi) = numCategories(fi)。

findSplitsBins

这是DecisionTree.scala中一个方法,基于buildMetadata统计的划分点数目和分箱数,加上数据采样,确定每个特征的划分点值。

如果存在连续值特征,则对数据集进行无放回采样,采样个数为max(maxBins*maxBins, 10000)。接着进行分箱操作(见findSplitsBinsBySorting函数,findSplitsForContinuousFeature函数)。

treeRDD,BaggedRDD: treeRDD的类型是:RDD[TreePoint],treePoint类有两个属性:label和binnedFeatures。BaggedPoint类有两个属性datum和subsampleWeights(在各个采样后数据集中的权重,即出现次数)。

def findSplitsBins(input:RDD[LabeledPoint],
metadata:DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) =
{ ... }

findBestSplits

对于当前节点,计算准则函数下的信息增益,选择最优切分特征和最优切分点。

其他:

configuration:

类名 枚举值
Algo Classification, Regression
FeatureType Continuous, Categorical
QuantileStrategy Sort, MinMax, ApproxHist
EnsembleCombiningStrategy Average, Sum, Vote

impurtiy,loss:

Impurity是准则函数的抽象类,其三个子类分别为Gini,Variance,Entropy。Impurities是Impurity的工厂类,负责解析用户参数,选择对应的准则函数。

private[mllib] object Impurities {
  def fromString(name: String): Impurity = name match {
    case "gini" => Gini
    case "entropy" => Entropy
    case "variance" => Variance
    case _ => throw new IllegalArgumentException(s"Did not recognize
    Impurity name: $name")
  }
}

Loss与Impurity情况类似,它是损失函数的抽象类,其三个子类分别是AbsoluteError,SquaredError,LogLoss。Losses是Loss的工厂类,负责判断损失的所属类别。

参考链接

spark mllib - decision tree

统计学习方法,李航著