1.Intro
决策树常用的算法有
决策树的
They are easy to
2.准则函数
主要用来衡量节点数据集合的有序性,有
Impurity | Task | Description | formula |
---|---|---|---|
熵 | 分类 | 假设有K类,样本点属于第i类的概率为\(p_i\),度量数据的不确定性 | \(E(D) =\sum_{i=1}^{K} -p_ilogp_i\) |
基尼 | 分类 | 同上 | \(Gini(D) = \sum_{i=1}^{K} p_i*(1-p_i)\) |
方差 | 回归 | 度量数据的离散程度 | \(Variance(D) = \frac{1}{N} \sum_{i=1}^N (y_i-\mu)^2\),\(\mu = \frac{1}{N} \sum_{i=1}^N y_i\) |
信息增益: 得知某特征X的信息而使目标Y的信息
具体方法:对训练集D,计算每个特征的信息增益,比较大小,选择信息增益大的特征列来分割集合。假设分割点将数据集分成左右两部分,则信息增益为 \( IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})\)
ID3s算法选择信息增益作分类,为了校正信息增益存在“偏向选择取值较多的特征”的缺点,C4.5采用
3 生成算法
具体方法:从根节点开始,计算所有可能的特征的准则值(应用准则函数),选择最优的特征作为节点的特征,由该特征的不同取值建立子节点;再对子节点递归调用以上方法,构建决策树。
CART
CART(classification and regression tree),又称分类与回归树。
CART的生成是一个递归地构建
算法描述:CART树生成算法
输入: 训练数据集D,停止条件
输出: 决策树f(x)
(1)在平方误差最小或Gini最小的准则下选择
(2)用选定的对(j,s)划分区域并决定相应的输出值
(3)继续对子区域调用(1),(2),直至满足停止条件
(4)将输入空间划分为M个区域\(R_1\),\(R_2\),…,\(R_M\).生成如下决策树:
$$f(x) = \sum_{m=1}^{M} c_m*I(x \in R_m)$$
算法的
4 剪枝过程(pruning)
将已生成的树进行简化的过程,通过极小化决策树整体的损失函数(loss function)来实现。 分为前向剪枝和后向剪枝。前向剪枝是在构造决策树的同时对树进行剪枝;后向剪枝是在决策树构建完成后,对树从根节点向上递归剪枝。
5.MLlib实现
连续值特征
对于连续值类型的特征,单台机器上的做法是选取所有出现过的值作为切分点候选集。加速计算的方法是对该集合按值大小排序,用排序好的数组做候选集。
在分布式中,对于规模大的数据集,排序是比较耗时的。于是先
类别值特征
对于离散值类型的特征,有m个值,那么最多产生\(2^{m-1}-1\)种划分可能,\(2*(2^{m-1}-1)\)个分箱。对于二分类或回归问题,可以将划分候选缩减到m-1种可能。
举个例子,在二分类问题中,一个特征有A,B,C三个值。特征值为A的样本集中label是1的比例为0.2,B和C分别为0.6,0.4。则产生了一个类别值的有序序列ACB,共有两种划分可能:A|C,B,A,C|B。
多分类问题中,离散特征会产生\(2^{m-1}-1\)种划分可能,当划分数大于maxBins,将m个类别值根据impurity排序,产生共m-1种划分可能。
bin和split
split是划分点,bin是划分区间。每个特征都分别有一组bin和split。
其中bin中
特征选择时,对于某个特征,计算不同切分变量下信息增益的大小,确定出该特征下的最佳切分点。选择全部候选特征集中最优的切分点作为当前决策。特征选择是树的level级别的并行,对于同一层次的节点,查找可以并行,查找操作的时间复杂度为O(L),L为树的层数(
终止条件
为了防止过拟合,采用前向剪枝。当任一以下情况发生,节点就终止划分,形成叶子节点:
(1)树高度达到maxDepth
(2)minInfoGain,当前节点的所有属性分割带来的信息增益都比这个值要小
(3)minInstancesPerNode,需要保证节点分割出的左右子节点的最少的样本数量达到这个值
源码分析
下面就一起揭开decision tree的真正面纱吧!
def buildMetadata(input:RDD[LabeledPoint], strategy:Strategy,
numTrees: Int, featureSubsetStrategy: String):DecisionTreeMetadata =
{...}
对于每个类别值特征fi,
findSplitsBins
def findSplitsBins(input:RDD[LabeledPoint],
metadata:DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) =
{ ... }
findBestSplits
configuration:
类名 | 枚举值 |
---|---|
Algo | Classification, Regression |
FeatureType | Continuous, Categorical |
QuantileStrategy | Sort, MinMax, ApproxHist |
EnsembleCombiningStrategy | Average, Sum, Vote |
impurtiy,loss:
Impurity是准则函数的抽象类,其三个子类分别为Gini,Variance,Entropy。Impurities是Impurity的工厂类,负责解析用户参数,选择对应的准则函数。
private[mllib] object Impurities {
def fromString(name: String): Impurity = name match {
case "gini" => Gini
case "entropy" => Entropy
case "variance" => Variance
case _ => throw new IllegalArgumentException(s"Did not recognize
Impurity name: $name")
}
}
Loss与Impurity情况类似,它是损失函数的抽象类,其三个子类分别是AbsoluteError,SquaredError,LogLoss。Losses是Loss的工厂类,负责判断损失的所属类别。
参考链接
统计学习方法,李航著