C++版OpenCV使用神经网络ANN进行mnist手写数字识别[通俗易懂]

C++版OpenCV使用神经网络ANN进行mnist手写数字识别[通俗易懂]说起神经网络,很多人以为只有Keras或者tensorflow才支持,其实OpenCV也支持神经网络的,下面就使用OpenCV的神经网络进行手写数字识别,训练10次的准确率就高达96%。环境准备:vs2015OpenCV4.5.0以下为ANN神经网络的训练代码:#include<iostream>#include<opencv.hpp>#include<string>#include<fstream>usingnamespacestd

大家好,又见面了,我是你们的朋友全栈君。

说起神经网络,很多人以为只有Keras或者tensorflow才支持,其实OpenCV也支持神经网络的,下面就使用OpenCV的神经网络进行手写数字识别,训练10次的准确率就高达96%。
环境准备:
vs2015
OpenCV4.5.0
以下为ANN神经网络的训练代码:

#include<iostream>
#include<opencv.hpp>
#include <string>
#include <fstream>
using namespace std;
using namespace cv;
using namespace cv::ml;
//小端存储转换
int reverseInt(int i);
//读取image数据集信息
Mat read_mnist_image(const string fileName);
//读取label数据集信息
Mat read_mnist_label(const string fileName);
//将标签数据改为one-hot型
Mat one_hot(Mat label, int classes_num);
string train_images_path = "G:/vs2015_opencv_ml/mnist/train-images.idx3-ubyte";
string train_labels_path = "G:/vs2015_opencv_ml/mnist/train-labels.idx1-ubyte";
string test_images_path = "G:/vs2015_opencv_ml/mnist/t10k-images.idx3-ubyte";
string test_labels_path = "G:/vs2015_opencv_ml/mnist/t10k-labels.idx1-ubyte";
int main()
{ 

/* ---------第一部分:训练数据准备----------- */
//读取训练标签数据 (60000,1) 类型为int32
Mat train_labels = read_mnist_label(train_labels_path);
//ann神经网络的标签数据需要转为one-hot型
train_labels = one_hot(train_labels, 10);
//读取训练图像数据 (60000,784) 类型为float32 数据未归一化
Mat train_images = read_mnist_image(train_images_path);
//将图像数据归一化
train_images = train_images / 255.0;
//读取测试数据标签(10000,1) 类型为int32 测试标签不用转为one-hot型
Mat test_labels = read_mnist_label(test_labels_path);
//读取测试数据图像 (10000,784) 类型为float32 数据未归一化
Mat test_images = read_mnist_image(test_images_path);
//归一化
test_images = test_images / 255.0;
/* ---------第二部分:构建ann训练模型并进行训练----------- */
cv::Ptr<cv::ml::ANN_MLP> ann = cv::ml::ANN_MLP::create();
//定义模型的层次结构 输入层为784 隐藏层为64 输出层为10
Mat layerSizes = (Mat_<int>(1, 3) << 784, 64, 10);
ann->setLayerSizes(layerSizes);
//设置参数更新为误差反向传播法
ann->setTrainMethod(ANN_MLP::BACKPROP, 0.001, 0.1);
//设置激活函数为sigmoid
ann->setActivationFunction(ANN_MLP::SIGMOID_SYM, 1.0, 1.0);
//设置跌打条件 最大训练次数为100
ann->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER | TermCriteria::EPS, 10, 0.0001));
//开始训练
cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(train_images, cv::ml::ROW_SAMPLE,train_labels);
cout << "开始进行训练..." << endl;
ann->train(train_data);
cout << "训练完成" << endl;
/* ---------第三部分:在测试数据集上预测计算准确率----------- */
Mat pre_out;
//返回值为第一个图像的预测值 pre_out为整个batch的预测值集合
cout << "开始进行预测..." << endl;
float ret = ann->predict(test_images, pre_out);
cout << "预测完成" << endl;
//计算准确率
int equal_nums = 0;
for (int i = 0; i < pre_out.rows; i++)
{ 

//获取每一个结果的最大值所在下标
Mat temp = pre_out.rowRange(i, i + 1);
double maxVal = 0;
cv::Point maxPoint;
cv::minMaxLoc(temp,NULL, &maxVal,NULL, &maxPoint);
int max_index = maxPoint.x;
int test_index = test_labels.at<int32_t>(i, 0);
if (max_index == test_index)
{ 

equal_nums++;
}
}
float acc = float(equal_nums) / float(pre_out.rows);
cout << "测试数据集上的准确率为:" << acc * 100 << "%" << endl;
//保存模型
ann->save("mnist_ann.xml");
getchar();
return 0;
}
;
int reverseInt(int i) { 

unsigned char c1, c2, c3, c4;
c1 = i & 255;
c2 = (i >> 8) & 255;
c3 = (i >> 16) & 255;
c4 = (i >> 24) & 255;
return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
}
Mat read_mnist_image(const string fileName) { 

int magic_number = 0;
int number_of_images = 0;
int n_rows = 0;
int n_cols = 0;
Mat DataMat;
ifstream file(fileName, ios::binary);
if (file.is_open())
{ 

cout << "成功打开图像集 ..." << endl;
file.read((char*)&magic_number, sizeof(magic_number));//幻数(文件格式)
file.read((char*)&number_of_images, sizeof(number_of_images));//图像总数
file.read((char*)&n_rows, sizeof(n_rows));//每个图像的行数
file.read((char*)&n_cols, sizeof(n_cols));//每个图像的列数
magic_number = reverseInt(magic_number);
number_of_images = reverseInt(number_of_images);
n_rows = reverseInt(n_rows);
n_cols = reverseInt(n_cols);
cout << "幻数(文件格式):" << magic_number
<< " 图像总数:" << number_of_images
<< " 每个图像的行数:" << n_rows
<< " 每个图像的列数:" << n_cols << endl;
cout << "开始读取Image数据......" << endl;
DataMat = Mat::zeros(number_of_images, n_rows * n_cols, CV_32FC1);
for (int i = 0; i < number_of_images; i++) { 

for (int j = 0; j < n_rows * n_cols; j++) { 

unsigned char temp = 0;
file.read((char*)&temp, sizeof(temp));
//可以在下面这一步将每个像素值归一化
float pixel_value = float(temp);
//按照行将像素值一个个写入Mat中
DataMat.at<float>(i, j) = pixel_value;
}
}
cout << "读取Image数据完毕......" << endl;
}
file.close();
return DataMat;
}
Mat read_mnist_label(const string fileName) { 

int magic_number;
int number_of_items;
Mat LabelMat;
ifstream file(fileName, ios::binary);
if (file.is_open())
{ 

cout << "成功打开标签集 ... " << endl;
file.read((char*)&magic_number, sizeof(magic_number));
file.read((char*)&number_of_items, sizeof(number_of_items));
magic_number = reverseInt(magic_number);
number_of_items = reverseInt(number_of_items);
cout << "幻数(文件格式):" << magic_number << " ;标签总数:" << number_of_items << endl;
cout << "开始读取Label数据......" << endl;
//CV_32SC1代表32位有符号整型 通道数为1
LabelMat = Mat::zeros(number_of_items, 1, CV_32SC1);
for (int i = 0; i < number_of_items; i++) { 

unsigned char temp = 0;
file.read((char*)&temp, sizeof(temp));
LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
}
cout << "读取Label数据完毕......" << endl;
}
file.close();
return LabelMat;
}
//将标签数据改为one-hot型
Mat one_hot(Mat label, int classes_num)
{ 

//[2]->[0 1 0 0 0 0 0 0 0 0]
int rows = label.rows;
Mat one_hot = Mat::zeros(rows, classes_num, CV_32FC1);
for (int i = 0; i < label.rows; i++)
{ 

int index = label.at<int32_t>(i, 0);
one_hot.at<float>(i, index) = 1.0;
}
return one_hot;
}

执行代码,训练结果如下:

成功打开标签集 ...
幻数(文件格式):2049  ;标签总数:60000
开始读取Label数据......
读取Label数据完毕......
成功打开图像集 ...
幻数(文件格式):2051 图像总数:60000 每个图像的行数:28 每个图像的列数:28
开始读取Image数据......
读取Image数据完毕......
成功打开标签集 ...
幻数(文件格式):2049  ;标签总数:10000
开始读取Label数据......
读取Label数据完毕......
成功打开图像集 ...
幻数(文件格式):2051 图像总数:10000 每个图像的行数:28 每个图像的列数:28
开始读取Image数据......
读取Image数据完毕......
开始进行训练...
训练完成
开始进行预测...
预测完成
测试数据集上的准确率为:96.26%

从上可知,使用ANN神经网络仅仅训练10次,就可以达到96.24%的识别率,增大训练次数,这个识别率还会提高,而且ann的模型文件非常小,才一兆多一点,由此可知,ANN模型非常适合端上部署。
在这里插入图片描述
使用ann的模型文件识别OpenCV加载的手写数字图片,代码如下:

#include<iostream>
#include<opencv.hpp>
using namespace std;
using namespace cv;
int main()
{ 

//读取一张手写数字图片(28,28)
Mat image = cv::imread("shuzi1.jpg", 0);
Mat img_show = image.clone();
//更换数据类型有uchar->float32
image.convertTo(image, CV_32F);
//归一化
image = image / 255.0;
//(1,784)
image = image.reshape(1, 1);
//加载ann模型
cv::Ptr<cv::ml::ANN_MLP> ann= cv::ml::StatModel::load<cv::ml::ANN_MLP>("mnist_ann.xml");
//预测图片
Mat pre_out;
float ret = ann->predict(image,pre_out);
double maxVal = 0;
cv::Point maxPoint;
cv::minMaxLoc(pre_out, NULL, &maxVal, NULL, &maxPoint);
int max_index = maxPoint.x;
cout << "图像上的数字为:" << max_index << " 置信度为:" << maxVal << endl;
cv::imshow("img", img_show);
cv::waitKey(0);
getchar();
return 0;
}

执行以上代码,结果如下:
在这里插入图片描述
由此可见,使用该ANN模型能正确识别手写数字,并且ANN模型由于保存的是权重参数,因此模型文件极小,非常适合在端上进行部署。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/147908.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)
blank

相关推荐

  • python 获取时间戳_datetime获取当前时间

    python 获取时间戳_datetime获取当前时间1、获取秒级、毫秒级和微秒级时间戳importtimeimportdatetimet=time.time()#当前时间print(t)#原始时间数据print(int(t))#秒级时间戳print(int(round(t*1000)))#毫秒级时间戳print(int(round(t*1000000)))#微秒级时间戳结果:1634191096.03610181634191096163419109603616341910960361

  • (怪盗基德的滑翔翼)(最长上升子序列)[通俗易懂]

    (怪盗基德的滑翔翼)(最长上升子序列)[通俗易懂]原题链接怪盗基德是一个充满传奇色彩的怪盗,专门以珠宝为目标的超级盗窃犯。而他最为突出的地方,就是他每次都能逃脱中村警部的重重围堵,而这也很大程度上是多亏了他随身携带的便于操作的滑翔翼。有一天,怪盗基德像往常一样偷走了一颗珍贵的钻石,不料却被柯南小朋友识破了伪装,而他的滑翔翼的动力装置也被柯南踢出的足球破坏了。不得已,怪盗基德只能操作受损的滑翔翼逃脱。假设城市中一共有N幢建筑排成一条线,每幢建筑的高度各不相同。初始时,怪盗基德可以在任何一幢建筑的顶端。他可以选择一个方向逃跑,但是不能中途改变方向

  • 近场动力学matlab程序_一阶惯性环节matlab

    近场动力学matlab程序_一阶惯性环节matlab本发明属于过程控制技术领域,尤其涉及一种镇定一阶惯性加纯滞后系统的线性自抗扰控制器设计方法,进一步涉及一种用于具有时滞的工业过程控制系统的自抗扰控制器设计方法。背景技术:时滞作为一种常见的物理现象,在工业过程和生产生活中随处可见,例如管道对油气的输送、线缆对信号的传递、锅炉的燃烧等过程。这一类过程具有的共性即被控量不能立即对控制量的作用做出反应,这样的特点决定了被控对象输入与输出之间不同步的开环特…

  • Java集合容器面试题(2020最新版)「建议收藏」

    Java集合容器面试题(2020最新版)「建议收藏」文章目录集合容器概述什么是集合集合的特点集合和数组的区别使用集合框架的好处常用的集合类有哪些?List,Set,Map三者的区别?List、Set、Map是否继承自Collection接口?List、Map、Set三个接口存取元素时,各有什么特点?集合框架底层数据结构哪些集合类是线程安全的?Java集合的快速失败机制“fail-fast”?怎么确保一个集合不能被修改?Collection…

  • Cocos发展Visual Studio下一个libcurl图书馆开发环境的搭建

    Cocos发展Visual Studio下一个libcurl图书馆开发环境的搭建

  • springboot启动原理总结_Springboot启动流程

    springboot启动原理总结_Springboot启动流程说明:我这里只说结果,和简单的代码,面试应该是够了,毕竟源码内容不是所有人都能记住的,如果要学习源码请看其他大佬的文章,写的比较详细,而且差不多都一样。背景:面试经常会问道springboot启动流程或者原理,看了多数博友的文章,都是大同小异,但是面试的时候不可能那么多,所以我将启动流程总结一下。启动流程:1.启动springboot这需要执行SpringApplication执行类即可2.执行的时候执行两个重要的代码,@springBootAppli…

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号