sklearn常用聚类算法模型【KMeans、DBSCAN】实践

sklearn常用聚类算法模型【KMeans、DBSCAN】实践聚类算法是很重要的一类算法模型,在实际的应用实践中是会经常使用到的,最近的工作类型中大多偏向于有监督学习类型模型的使用,而对于无监督算法模型的使用则使用得相对少了很多,今天就简单的回归一下聚类算法模型,主要是KMeans模型和DBSACN模型的使用。这两种模型可以说是聚类算法领域里面很具有代表性的算法了,前者是基于样本之间距离的聚类,后者是基于样本集密度的聚类。殊途同…

大家好,又见面了,我是你们的朋友全栈君。

       聚类算法是很重要的一类算法模型,在实际的应用实践中是会经常使用到的,最近的工作类型中大多偏向于有监督学习类型模型的使用,而对于无监督算法模型的使用则使用得相对少了很多,今天就简单的回归一下聚类算法模型,主要是KMeans模型和DBSACN模型的使用。

      这两种模型可以说是聚类算法领域里面很具有代表性的算法了,前者是基于样本之间距离的聚类,后者是基于样本集密度的聚类。殊途同归,本二者的本质都是为了最终实现:簇间距离最大,簇内距离最小的目的。

       本来想好好介绍一番原理,可是忽然觉得想着想着就越想越多了,这里索性就不再讲解原理了,网上也有很多系列的原理讲解的文章,可以去看看的,这里主要是实践使用分析为主。

        使用聚类算法,首先我们要有数据集才可以,这里为了简单,直接使用的是sklearn提供的数据集生成模块,来直接生成我们所需要的数据集,具体实现如下:

def getClusterData(flag=True,ns=1000,nf=2,centers=[[-1,-1],[1,1],[2,2]],cluster_std=[0.4,0.5,0.2]):
    '''
    得到回归数据
    centers(簇中心的个数或者自定义的簇中心)
    cluster_std(簇数据方差代表簇的聚合程度)
    '''
    if flag:
        cluster_X,cluster_y=datasets.make_circles(n_samples=ns,factor=.6,noise=.05)
    else:
        cluster_X,cluster_y=datasets.make_blobs(n_samples=ns,n_features=nf,centers=centers,
                                                cluster_std=cluster_std,random_state=9)
    return cluster_X,cluster_y

       数据集生成的代码块中已经加入了我的注释,相信是比较容易看明白的。

       接下来,我们要对原始生成的数据集进行划分,生成训练集和测试集,具体实现方法如下:

def dataSplit(dataset,label,ratio=0.3):
    '''
    数据集分割-----训练集、测试集合
    '''
    try:
        X_train,X_test,y_train,y_test=train_test_split(dataset,label,test_size=ratio)
    except:
        dataset,label=np.array(dataset),np.array(label)
        X_train,X_test,y_train,y_test=train_test_split(dataset,label,test_size=ratio)
    print '--------------------------------split_data shape-----------------------------------'
    print len(X_train), len(y_train)
    print len(X_test), len(y_test)
    return X_train,X_test,y_train,y_test

       上述代码块实现了原始数据集的分割。

        之后,我们需要做一点模型持久化存储于加载使用的工作,这也是机器学习或者是深度学习里面很重要的组成部分了,因为当数据集体量增大的时候,每次使用模型都重复去训练模型的时间代价或者是计算代价都是很大的,所以这里要做好已训练完成模型的持久化工作,具体实现方式如下:

def saveModel(model,save_path="model.pkl"):
    '''
    模型持久化存储
    '''
    joblib.dump(model,save_path)
    print u"持久化存储完成!"


def loadModel(model_path="model.pkl"):
    '''
    加载保存本地的模型
    '''
    model=joblib.load(model_path)
    return model

          上述的代码块实现了训练完成模型的本地化存储于加载使用。

           完成上述全部工作后,就要开始模型的搭建使用了,具体如下:

def clusterModel(flag=True):
    '''
    Kmeans算法关键参数:
    n_clusters:数据集中类别数目

    DBSCAN算法关键参数:
    eps: DBSCAN算法参数,即我们的ϵ-邻域的距离阈值,和样本距离超过ϵ的样本点不在ϵ-邻域内
    min_samples: DBSCAN算法参数,即样本点要成为核心对象所需要的ϵ-邻域的样本数阈值
    '''
    X,y=getClusterData(flag=flag,ns=3000,nf=5,centers=[[-1,-1],[1,1],[2,2]],
                       cluster_std=[0.4,0.5,0.2])
    X_train,X_test,y_train,y_test=dataSplit(X,y,ratio=0.3)
    #绘图
    plt.figure(figsize=(16,8))
    #Kmeans模型
    model=KMeans(n_clusters=3,random_state=9)
    model.fit(X_train)
    y_pred=model.predict(X_test)
    plt.subplot(121)
    plt.scatter(X_test[:, 0], X_test[:, 1],c=y_pred)
    plt.title('KMeans Cluster Result')
    #DESCAN模型
    # 下面的程序报错:AttributeError: 'DBSCAN' object has no attribute 'predict'
    # model=DBSCAN(eps=0.1,min_samples=10)
    # model.fit(X_train)
    # y_pred=model.predict(X_test)
    # 改为这样形式的可以了
    y_pred=DBSCAN(eps=0.05,min_samples=10).fit_predict(X_test)
    plt.subplot(122)
    plt.scatter(X_test[:, 0], X_test[:, 1],c=y_pred)
    plt.title('DBSCAN Cluster Result')
    if flag:
        plt.savefig('circleData.png')
    else:
        plt.savefig('blobData.png')

       上述代码块实现了KMeans模型和DBSACN模型的构建、训练和使用,我们对测试集的预测结果进行了可视化分析具体如下所示:

Circle数据集模型结果:

sklearn常用聚类算法模型【KMeans、DBSCAN】实践

非Circle数据集模型结果:

sklearn常用聚类算法模型【KMeans、DBSCAN】实践

      整体来看,上述两个数据集KMeans的综合表现优于DBSACN模型,不过这个只是一个简单的实验说明,就是为了熟练一下这两种常用聚类模型的使用,记录学习一下。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/152335.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • java wifi开发教程_WiFi技术「建议收藏」

    java wifi开发教程_WiFi技术「建议收藏」WiFi技术啥MU-MIMO:多用户-多输入多输出2x2MIMO:2个输入in,2个输出out。?一根天线既做接收也做发送,2×2即两根天线,同理4×4即为4根天线?RedmiAX6共6根天线,4根作为5G天线,2根作为2.4G天线RedmiAX6支持2x2160MHz及4x480MHz两种MU-MIMO工作模式802.11对照表协议频宽(MHz)单天线速率(Mbps…

  • phy芯片与rj45接法_232接口详细接线图

    phy芯片与rj45接法_232接口详细接线图千兆PHY通过网络变压器连接到RJ45接口,一共有4对差分线MDI[0..3]+/-。一般的接法是: MDI[0]+:RJ45[1] MDI[0]-:RJ45[2] MDI[1]+:RJ45[3] MDI[1]-:RJ45[6] MDI[2]+:RJ45[4] MDI[2]-:RJ45[5] MDI[3]+:RJ45[7]

  • js中将json字符串转换成json对象_字符串零终止符

    js中将json字符串转换成json对象_字符串零终止符今天遇到一个奇怪的问题,解析二维码后获得了一个JSON字符串,将JSON字符串转换成JSON对象的时候报错了。报错如下:代码如下:检查了无数次数据,数据是JSON字符串,引号也都是英文的,就是莫名其妙的转换不了。最后无奈了,终于找到一个解决办法,不用JSON.parse(xx)转换,用eval(‘(‘+xx+’)’)方法转换,最终解决了这个问题,虽然我还是不明白为什么JSON.parse转换会报错,有知道原因的大神吗?解决方法:数据如下:language{“ID”:”98-FA-9B

  • pycharm git使用_pycharm上传github

    pycharm git使用_pycharm上传githubpycharm操作git一、git安装和使用​ 安装操作:https://www.cnblogs.com/ximiaomiao/p/7140456.html1.如何使用git将本地代码上传到远程仓库初始化gitinit查看当前仓库状态gitstatus将项目的文件添加到仓库中gitadd<文件名>gitadd.(上传所有文件)将add的文件commit到仓库gitcommit-m”备注”将本地仓库关联到远程仓库gi

  • Android进程间通信(IPC)机制Binder简介和学习计划

    Android进程间通信(IPC)机制Binder简介和学习计划

    2021年12月31日
  • 网站前端性能优化

    继前面几篇文章后再来说说老生常谈的话题,怎么样提升前端性能。文中很多取材自网络及《HighPerformanceWebSites》,并根据自己工作中所接触到的知识整理而成。http://hov

    2021年12月24日

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号