使用KNN识别MNIST手写数据集(手写,不使用KNeighborsClassifier)

KNN识别MNIST手写数据集(32*32维),根据KNN原理一步步实现。

大家好,又见面了,我是你们的朋友全栈君。

数据集
提取码:mrfr

浏览本文前请先搞懂K近邻的基本原理:最简单的分类算法之一:KNN(原理解析+代码实现)

算法实现步骤:

  1. 数据处理。每一个数字都是一个32X32维的数据,如下所示:
    在这里插入图片描述
    knn中邻居一词指的就是距离相近。我们要想计算两个样本之间的距离,就必须将每一个数字变成一个向量。具体做法就是将32X32的数据每一行接在一起,形成一个1X1024的数据,这样我们就可以计算欧式距离。
  2. 计算测试数据到所有训练数据的距离,并按照从小到大排序,选出前K个
  3. 根据距离计算前K个样本的权重
  4. 将相同的训练样本的权重加起来,返回权重最大样本的标签

代码实现:

import os
def load_data(path):
check = [i for i in range(10)]
final_data = []
for i in range(10):
final_data.append([])
files = os.listdir(path)    #文件夹
for file in files:
data = open(path + "/" + file)
str = ""      #将所有数据接在一起
temp = []
for line in data.readlines():
str = str + line[:-1]   #去掉回车,一行接一行
for i in str:
temp.append(int(i))   #变成数字
final_data[check.index(int(file[0]))].append(temp)   #根据标签放在列表相应的位置
return final_data, len(files)
def knn_mnist(K,test_data):
train_data, length = load_data('manifold/digits/trainingDigits')
distance = []     #存储测试数据到所有训练数据的距离
for i in range(len(train_data)):
for j in range(len(train_data[i])):
res = 0
for k in range(len(test_data)):
res += (test_data[k]-train_data[i][j][k]) ** 2   #欧氏距离
distance.append([res ** 0.5, i])   #距离+训练集数据标签
distance = sorted(distance, key=(lambda x: x[0]))  #按距离从小到大排序
weight = []   #权重与序号
sum_distance = 0.0
for i in range(K):
sum_distance += distance[i][0]   #计算前K个距离的和
for i in range(K):
weight.append([1 - distance[i][0] / sum_distance, distance[i][1]])  #权重+序号
#将相同序号的加起来
num = []   #统计有哪些序号
for i in range(K):
num.append(weight[i][1])
num = list(set(num))   #去重
final_res = []
for i in range(len(num)):
res = 0.0
for j in range(len(weight)):
if weight[j][1] == num[i]:   #前K个标签一样的样本权值加起来
res += weight[j][0]
final_res.append([res, num[i]])
final_res = sorted(final_res, key=(lambda x: x[0]),reverse=True)  # 按照权重从大到小排序
return final_res[0][1]   #最终返回最大权值对应的标签
def test():
K = 5
test_data, length = load_data('manifold/digits/testDigits')
#测试
for i in range(len(test_data)):
for j in range(len(test_data[i])):
print(knn_mnist(K, test_data[i][j]))
if __name__ == '__main__':
test()

  欢迎大家关注我的微信公众号:KI的算法杂记,有什么问题可以直接发私信。

在这里插入图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/125683.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • Python中字符串String去除出换行符(\n,\r)和空格的问题「建议收藏」

    Python中字符串String去除出换行符(\n,\r)和空格的问题「建议收藏」Python中字符串String去除出换行符和空格的问题(\n,\r)在Python的编写过程中,获取到的字符串进场存在不明原因的换行和空格,如何整合成一个单句,成为问题。方法:一、去除空格“·”代表的为空格  strip()"···xyz···".strip()#returns"xyz""···xyz···".lstrip()…

  • python2 nonlocal_python unboundlocalerror

    python2 nonlocal_python unboundlocalerror在廖雪峰的官网上看到一个很有意思题目。关于闭包的,有兴趣的朋友可以看一下这里,做一下这个题目,当然需要一点闭包的知识。下面我简述一下:利用闭包返回一个计数器函数,每次调用它返回递增整数。#修改下面这个函数defcreateCounter():defcounter():passreturncounter#测试:counterA=createCounter()print(counter…

  • 究竟什么是推荐?

    究竟什么是推荐?

  • md5 java 实现_MD5加密的Java实现

    md5 java 实现_MD5加密的Java实现在各种应用系统中,如果需要设置账户,那么就会涉及到储存用户账户信息的问题,为了保证所储存账户信息的安全,通常会采用MD5加密的方式来,进行储存。首先,简单得介绍一下,什么是MD5加密。MD5的全称是Message-DigestAlgorithm5(信息-摘要算法),在90年代初由MITLaboratoryforComputerScience和RSADataSecurityInc的…

  • 退出卸载360、QAX 天擎,无需密码

    退出卸载360、QAX 天擎,无需密码退出卸载360、QAX天擎,无需密码天擎企业版,退出和卸载是需要管理员密码进入360天擎,点击设置=>防护中心=>自我保护功能,去掉勾选,确认\360Safe\EntClient\conf\EntBase.dat目录目录查找:在桌面右下角找到天擎应用程序右键,进入程序安装位置,即可进入安装目录EntBase.dat文件删除uipassqtpass两行=后面的两个字符串即可关闭退出卸载密码。这里如果没有权限无法修改,先将EntBase.dat复制到其他文件夹(

  • Python建立数据库

    Python建立数据库Python建立数据库所谓数据库,即存储数据的仓库。每一个数据库可以存放若干个数据表,这里的数据表就是我们通常所说的二维表,分为行和列,每一行称为一条记录,每一列称为一个字段。表中的列是固定的,可变的是行。要注意,我们通常在列中指定数据的类型,在行中添加数据,即我们每次添加一条记录,就添加一行,而不是添加一列。对数据库的操作可以概括为就是向数据库中添加、删除、修改和查询数据,其中查询功能最为复杂。检查数据库是否存在你可以通过使用“SHOWDATABASES”语句列出系统中所有数据库,检查数据库是否存

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号