监督学习——决策树理论与实践(下):回归决策树(CART)[通俗易懂]

监督学习——决策树理论与实践(下):回归决策树(CART)

大家好,又见面了,我是你们的朋友全栈君。

介绍

决策树分为分类决策树和回归决策树:

bacd

上一篇介绍了分类决策树以及Python实现分类决策树: 监督学习——决策树理论与实践(上):分类决策树

         决策树是一种依托决策而建立起来的一种树。在机器学习中,决策树是一种预测模型,代表的是一种对象属性与对象值之间的一种映射关系,每一个节点代表某个对象/分类,树中的每一个分叉路径代表某个可能的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值

        通过训练数据构建决策树,可以高效的对未知的数据进行分类。决策数有两大优点:1)决策树模型可以读性好,具有描述性,有助于人工分析;2)效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度。

      决策树是一颗树形的数据结构,可以是多叉树也可以是二叉树,决策树实际上是一种基于贪心策略构造的,每次选择的都是最优的属性进行分裂。

      决策树也是一种监督学习算法,它的样本是(x,y)形式的输入输出样例。

  回归树:

         相对于上一篇所讲的决策树,这篇所讲的回归树主要解决回归问题,所以给定的训练数据输入和标签都是连续的。


CART回归树生成算法

决策树的生成

        CART算法的思路是将特征空间切分为m个不同的子空间,通过测试数据(落在每个子空间中的测试数据)来计算每个子空间的输出值(对应下式中的Cm)。当这样的空间几何生成之后就可以很方便的将一个未知数据映射到某一个子空间Ri中,将Ci的值作为该未知数据的输出值。

image

这里Cm的取值一般采用均值算法,即取所有落在该子空间的测试数据的均作为该子空间的值:

image

这里肯定会涉及到一个,这也是CART算法的关键: 如何去划分一个一个子空间?如何去选择第j个变量Xj和它取值s作为切分变量和切分点,并定义成两个区域。这里《统计学方法》中给出了算法思路:

image

算法实现时,比那里所有切分向量,切分点是测试数据在Xj上的所有取值集合。通过5.19就能计算出当前最佳的切分向量j和切分点x以及划分成的两个区域的取值c1,c2。(该部分的Python实现对应下文中chooseBestSplit函数)

      当对一个整体测试数据调用上面逻辑后会得到一个j和x值,通过这两个值将空间分成了两个空间,再分别对两个子空间调用上面的逻辑,这样递归下去就能生成一棵决策树。(对应下文中createTree函数

决策树的剪枝

CART剪枝算法从“完全生长”的决策树的底端减去一些子树,使决策树变小(模型变简单),从而能够对未知数据有更准确的预测。

后续待补充


CART算法Python实现

数据加载

加载测试数据,以及测试数据的值(X,Y),这里数据和值都存放在一个矩阵中。

def loadDataSet(fileName):      #general function to parse tab -delimited floats
    dataMat = []                #assume last column is target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float,curLine) #map all elements to float()
        dataMat.append(fltLine)
    return dataMat

数据划分

该函数用于切分数据集,将测试数据某一列中的元素大于和小于的测试数据分开,分别放到两个矩阵中:

def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
    return mat0,mat1

输入参数 feature 为指定的某一列

value为切分点的值,通过该该值将dataset一份为二

寻找最优切分特征以及切分点

这里涉及到三个函数,分别在代码注释中进行了说明,真正计算最优值的函数为最后一个。

# 叶节点值计算函数: 这里以均值作为叶节点值
def regLeaf(dataSet):#returns the value used for each leaf
    return mean(dataSet[:,-1])

# 预测误差计算函数:这里用均方差表示
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]

# 遍历每一列中每个value值,找到最适合分裂的列和切分点
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0];   # 均方差最小优化值,如果大于该值则没有必要切分
    tolN = ops[1]    # 需要切分数据的最小长度,如果已经小于该值,则无需再切分
    #if all the target variables are the same value: quit and return value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:,featIndex]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than a threshold don't do the split
    if (S - bestS) < tolS:
        return None, leafType(dataSet) #exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit cond 3
        return None, leafType(dataSet)
    return bestIndex,bestValue#returns the best feature to split on
                              #and the value used for that split

回归树的创建

        在上面函数基础之上,创建一个回归树也就不难了:

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
    if feat == None: return val #if the splitting hit a stop condition return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

这里是一个递归调用,需要注意函数的终止条件,这里当数据集不能再分时才会触发终止条件,实际中这种操作很有可能会出现过拟合,可以认为地加一些终止条件进行“预剪枝”

 

参考:

《机器学习实战》

《统计学习方法》

https://blog.csdn.net/u014568921/article/details/45082197

转载于:https://www.cnblogs.com/NeilZhang/p/9216354.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/107458.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • 网线RJ45接口排线示意图(做网线备用)「建议收藏」

    网线RJ45接口排线示意图(做网线备用)「建议收藏」网线RJ45接口排线示意图(做网线备用)RJ45有两种绕线方式,T-568A和T-568B。注意:绝大多数设备用的都是T-568B!!!请参照T-568B的线序!!!我的热门文章推荐多路视频直播用在线云导播切换的效果测试 如何把视频转换生成二维码,扫码直接播放? 有哪些网站上传视频是不会插入广告的? 怎么把视频生成二维码?微信扫二维码就可以观看?不要广告的 常用照片尺寸对照表,照片大小看这个表就对了 视频直播推流攻略(整理的各大平台推流界面) html5视频倍.

  • java语言算法描述_六大java语言经典算法[通俗易懂]

    java语言算法描述_六大java语言经典算法[通俗易懂]在程序员们进行编程的时候,对各种数据的处理是少不了的,java语言算法在这个时候就十分重要了。数据算法有很多种,也并不区分哪种计算机语言使用,但是有程序员们常用的java语言经典算法,下面就简单介绍一下六大经典java语言算法。一、冒泡排序(BubbleSort)1、基本思想:两个数比较大小,较大的数下沉,较小的数冒起来。2、算法描述:(1)比较相邻的元素。如果第一个比第二个大,就交换它们两个;…

  • 【12】进大厂必须掌握的面试题-持续测试面试

    点击上方“全栈程序员社区”,星标公众号 重磅干货,第一时间送达 Q1。什么是连续测试? 我将建议您遵循以下提到的解释:连续测试是作为软件交付管道的一部分执行自动测试的过程,以获得与…

  • mysql一键部署脚本

    mysql一键部署脚本

  • Pytest(1)安装与入门[通俗易懂]

    Pytest(1)安装与入门[通俗易懂]pytest介绍pytest是python的一种单元测试框架,与python自带的unittest测试框架类似,但是比unittest框架使用起来更简洁,效率更高。根据pytest的官方网站介绍,它

  • IOC 控制反转[通俗易懂]

    IOC 控制反转[通俗易懂]SpringFramework概述https://blog.csdn.net/centrl/article/details/115519480通过前面的学习,我们至少已经知道IOC,下面我们就来说说IOC是个什么东西。1.写在前面首先来想一件事,作为程序员,怎么开发程序才最巴适?我觉得最起码有两点:开发简单、升级简单。开发简单,就是我们只管写业务逻辑(培养只会写if-else的程序员)。 升级简单,这里也包含两点:我们使用的技术(可理解为框架)出了什么问…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号