kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。于是我们会给出一系列的最大深度的值,比如{‘max_depth’:[1,2,3,4,5]},我们会尽可能包含最优最大深度。不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。

以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。

于是我们会给出一系列的最大深度的值,比如 {‘max_depth’: [1,2,3,4,5]},我们会尽可能包含最优最大深度。

不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一种可靠的评分方法,对每个最大深度的决策树模型都进行评分,这其中非常经典的一种方法就是交叉验证,下面我们就以K折交叉验证为例,详细介绍它的算法过程。

首先我们先看一下数据集是如何分割的。我们拿到的原始数据集首先会按照一定的比例划分成训练集和测试集。比如下图,以8:2分割的数据集:

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

训练集用来训练我们的模型,它的作用就像我们平时做的练习题;测试集用来评估我们训练好的模型表现如何,它的作用像我们做的高考题,这是要绝对保密不能提前被模型看到的。

因此,在K折交叉验证中,我们用到的数据是训练集中的所有数据。我们将训练集的所有数据平均划分成K份(通常选择K=10),取第K份作为验证集,它的作用就像我们用来估计高考分数的模拟题,余下的K-1份作为交叉验证的训练集。

对于我们最开始选择的决策树的5个最大深度 ,以 max_depth=1 为例,我们先用第2-10份数据作为训练集训练模型,用第1份数据作为验证集对这次训练的模型进行评分,得到第一个分数;然后重新构建一个 max_depth=1 的决策树,用第1和3-10份数据作为训练集训练模型,用第2份数据作为验证集对这次训练的模型进行评分,得到第二个分数……以此类推,最后构建一个 max_depth=1 的决策树用第1-9份数据作为训练集训练模型,用第10份数据作为验证集对这次训练的模型进行评分,得到第十个分数。于是对于 max_depth=1 的决策树模型,我们训练了10次,验证了10次,得到了10个验证分数,然后计算这10个验证分数的平均分数,就是 max_depth=1 的决策树模型的最终验证分数。

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)

对于 max_depth = 2,3,4,5 时,分别进行和 max_depth=1 相同的交叉验证过程,得到它们的最终验证分数。然后我们就可以对这5个最大深度的决策树的最终验证分数进行比较,分数最高的那一个就是最优最大深度,我们利用最优参数在全部训练集上训练一个新的模型,整个模型就是最优模型。

下面提供一个简单的利用决策树预测乳腺癌的例子:

from sklearn.model_selection import GridSearchCV, KFold, train_test_split

from sklearn.metrics import make_scorer, accuracy_score

from sklearn.tree import DecisionTreeClassifier

from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

X_train, X_test, y_train, y_test = train_test_split(

data[‘data’], data[‘target’], train_size=0.8, random_state=0)

regressor = DecisionTreeClassifier(random_state=0)

parameters = {
‘max_depth’: range(1, 6)}

scoring_fnc = make_scorer(accuracy_score)

kfold = KFold(n_splits=10)

grid = GridSearchCV(regressor, parameters, scoring_fnc, cv=kfold)

grid = grid.fit(X_train, y_train)

reg = grid.best_estimator_

print(‘best score: %f%grid.best_score_)

print(‘best parameters:’)

for key in parameters.keys():

print(%s: %d%(key, reg.get_params()[key]))

print(‘test score: %f%reg.score(X_test, y_test))

import pandas as pd

pd.DataFrame(grid.cv_results_).T

直接用决策树得到的分数大约是92%,经过网格搜索优化以后,我们可以在测试集得到95.6%的准确率:

best score: 0.938462

best parameters:

max_depth: 4

test score: 0.956140

转载自https://zhuanlan.zhihu.com/p/25637642

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/191614.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 7、常见面试口语提问问题汇总

    一、pleaseintroduceyourself. Goodmorning!Itisreallymyhonortohavethisopportunityforaninterview;IhopeIcanmakeagoodperformancetoday.I’mconfidentthatIcansucceed.NowIwil…

  • Ubuntu安装gcc4.1.2

    Ubuntu安装gcc4.1.2安装之前,系统中必须要有cc或者gcc等编译器,并且是可用的,或者用环境变量CC指定系统上的编译器。如果系统上没有编译器,不能安装源代码形式的GCC4.1.2。如果是这种情况,可以在网上找一个与你系统相适应的如RPM等二进制形式的GCC软件包来安装使用。本文介绍的是以源代码形式提供的GCC软件包的安装过程,软件包本身和其安装过程同样适用于其它Linux和Unix系统。系统上原来的GCC编译…

  • 逻辑回归(Logistic Regression)详解

    逻辑回归(Logistic Regression)详解逻辑回归也称作logistic回归分析,是一种广义的线性回归分析模型,属于机器学习中的监督学习。其推导过程与计算方式类似于回归的过程,但实际上主要是用来解决二分类问题(也可以解决多分类问题)。通过给定的n组数据(训练集)来训练模型,并在训练结束后对给定的一组或多组数据(测试集)进行分类。其中每一组数据都是由p个指标构成。(1)逻辑回归所处理的数据逻辑回归是用来进行分类的。例如,我们给出一个人的[身高,体重]这两个指标,然后判断这个人是属于”胖“还是”瘦“这一类。对于这个问题,我们可以先测量n个

    2022年10月25日
  • 学习 Web 开发技术的16个最佳教程网站和博客

    学习 Web 开发技术的16个最佳教程网站和博客互联网经过这么多年的发展,已经出现了众多的Web开发技术,像.Net/Java/PHP/Python/Ruby等等。对于Web开发人员来说,不管是初学者还是有一定经验的开发人员都需要时刻学

  • 谈谈怎么实现Oracle数据库分区表「建议收藏」

    谈谈怎么实现Oracle数据库分区表「建议收藏」Oracle数据库分区是作为Oracle数据库性能优化的一种重要的手段和方法,做手头的项目以前,只聆听过分区的大名,感觉特神秘,看见某某高手在讨论会上夸夸其谈时,真是骂自己学艺不精,最近作GPS方面的项目,处理的数据量达到了几十GB,为了满足系统的实时性要求,必须提高数据的查询效率,这样就必须通过分区,以解燃眉之急!先说说分区的好处吧!1) 增强可用性:如果表的某个分区出现故障,表在其他分

  • linux安装pycharm专业版_linux下pycharm使用

    linux安装pycharm专业版_linux下pycharm使用文件准备流程下载pycharm的linux版本的软件包,下载地址:http://www.jetbrains.com/pycharm/download/#section=linux解压$tar-xfpycharm-professional-2017.1.4.tar.gz进入解压后的文件夹下的bin目录,执行sudoshpycharm.sh在安装过程中选择激活码激活注

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号