BP网络识别数字
文本仅对0-9这十个文件夹中sample_mun_perclass个样本进行训练,直接通过API函数FindFirstFile和FindNextFile得到目录下文件,不需要对图片名编号
用了一下午时间去调这个代码,所以还有很多不完善的地方,以后有时间再去完善,比如:
本文的测试图片仅仅是单张测试,如果要测试准确率,可以根据前面训练时批量读取图片的代码进行简单修改,即可进行批量测试。
参考链接:http://blog.csdn.net/qq_15947787/article/details/51384258
特征采用8*16的二值化图像构成的128维向量作为输入层,3层128维的隐藏层,10维的输出层。输出用{1,0,0,0,0,…0}{0,1,0,0,0,…0}{0,0,1,0,0,…0}……{0,0,0,0,0,…1}表示
测试图像来源:http://www.cnblogs.com/ronny/p/opencv_road_more_01.html
测试代码:
//opencv2.4.9 + vs2012 + 64位
#include <windows.h>
#include <iostream>
#include <opencv2/opencv.hpp>
using namespace cv;
using namespace std;
char* WcharToChar(const wchar_t* wp)
{
char *m_char;
int len= WideCharToMultiByte(CP_ACP,0,wp,wcslen(wp),NULL,0,NULL,NULL);
m_char=new char[len+1];
WideCharToMultiByte(CP_ACP,0,wp,wcslen(wp),m_char,len,NULL,NULL);
m_char[len]='\0';
return m_char;
}
wchar_t* CharToWchar(const char* c)
{
wchar_t *m_wchar;
int len = MultiByteToWideChar(CP_ACP,0,c,strlen(c),NULL,0);
m_wchar=new wchar_t[len+1];
MultiByteToWideChar(CP_ACP,0,c,strlen(c),m_wchar,len);
m_wchar[len]='\0';
return m_wchar;
}
wchar_t* StringToWchar(const string& s)
{
const char* p=s.c_str();
return CharToWchar(p);
}
int main()
{
const string fileform = "*.png";
const string perfileReadPath = "charSamples";
const int sample_mun_perclass = 20;//训练字符每类数量
const int class_mun = 10;//训练字符类数
const int image_cols = 8;
const int image_rows = 16;
string fileReadName,
fileReadPath;
char temp[256];
float trainingData[class_mun*sample_mun_perclass][image_rows*image_cols] = {{0}};//每一行一个训练样本
float labels[class_mun*sample_mun_perclass][class_mun]={{0}};//训练样本标签
for(int i=0;i<=class_mun-1;++i)//不同类
{
//读取每个类文件夹下所有图像
int j = 0;//每一类读取图像个数计数
sprintf(temp, "%d", i);
fileReadPath = perfileReadPath + "/" + temp + "/" + fileform;
cout<<"文件夹"<<i<<endl;
HANDLE hFile;
LPCTSTR lpFileName = StringToWchar(fileReadPath);//指定搜索目录和文件类型,如搜索d盘的音频文件可以是"D:\\*.mp3"
WIN32_FIND_DATA pNextInfo; //搜索得到的文件信息将储存在pNextInfo中;
hFile = FindFirstFile(lpFileName,&pNextInfo);//请注意是 &pNextInfo , 不是 pNextInfo;
if(hFile == INVALID_HANDLE_VALUE)
{
exit(-1);//搜索失败
}
//do-while循环读取
do
{
if(pNextInfo.cFileName[0] == '.')//过滤.和..
continue;
j++;//读取一张图
//wcout<<pNextInfo.cFileName<<endl;
printf("%s\n",WcharToChar(pNextInfo.cFileName));
//对读入的图片进行处理
Mat srcImage = imread( perfileReadPath + "/" + temp + "/" + WcharToChar(pNextInfo.cFileName),CV_LOAD_IMAGE_GRAYSCALE);
Mat resizeImage;
Mat trainImage;
resize(srcImage,resizeImage,Size(image_cols,image_rows),(0,0),(0,0),CV_INTER_AREA);//使用象素关系重采样。当图像缩小时候,该方法可以避免波纹出现
threshold(resizeImage,trainImage,0,255,CV_THRESH_BINARY|CV_THRESH_OTSU);
for(int k = 0; k<image_rows*image_cols; ++k)
{
trainingData[i*sample_mun_perclass+(j-1)][k] = (float)trainImage.data[k];
//trainingData[i*sample_mun_perclass+(j-1)][k] = (float)trainImage.at<unsigned char>((int)k/8,(int)k%8);//(float)train_image.data[k];
//cout<<trainingData[i*sample_mun_perclass+(j-1)][k] <<" "<< (float)trainImage.at<unsigned char>(k/8,k%8)<<endl;
}
} while (FindNextFile(hFile,&pNextInfo) && j<sample_mun_perclass);//如果设置读入的图片数量,则以设置的为准,如果图片不够,则读取文件夹下所有图片
}
// Set up training data Mat
Mat trainingDataMat(class_mun*sample_mun_perclass, image_rows*image_cols, CV_32FC1, trainingData);
cout<<"trainingDataMat——OK!"<<endl;
// Set up label data
for(int i=0;i<=class_mun-1;++i)
{
for(int j=0;j<=sample_mun_perclass-1;++j)
{
for(int k = 0;k<class_mun;++k)
{
if(k==i)
labels[i*sample_mun_perclass + j][k] = 1;
else labels[i*sample_mun_perclass + j][k] = 0;
}
}
}
Mat labelsMat(class_mun*sample_mun_perclass, class_mun, CV_32FC1,labels);
cout<<"labelsMat:"<<endl;
cout<<labelsMat<<endl;
cout<<"labelsMat——OK!"<<endl;
//训练代码
cout<<"training start...."<<endl;
CvANN_MLP bp;
// Set up BPNetwork's parameters
CvANN_MLP_TrainParams params;
params.train_method=CvANN_MLP_TrainParams::BACKPROP;
params.bp_dw_scale=0.001;
params.bp_moment_scale=0.1;
params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER|CV_TERMCRIT_EPS,10000,0.0001); //设置结束条件
//params.train_method=CvANN_MLP_TrainParams::RPROP;
//params.rp_dw0 = 0.1;
//params.rp_dw_plus = 1.2;
//params.rp_dw_minus = 0.5;
//params.rp_dw_min = FLT_EPSILON;
//params.rp_dw_max = 50.;
//Setup the BPNetwork
Mat layerSizes=(Mat_<int>(1,5) << image_rows*image_cols,128,128,128,class_mun);
bp.create(layerSizes,CvANN_MLP::SIGMOID_SYM,1.0,1.0);//CvANN_MLP::SIGMOID_SYM
//CvANN_MLP::GAUSSIAN
//CvANN_MLP::IDENTITY
cout<<"training...."<<endl;
bp.train(trainingDataMat, labelsMat, Mat(),Mat(), params);
bp.save("../bpcharModel.xml"); //save classifier
cout<<"training finish...bpModel1.xml saved "<<endl;
//测试神经网络
cout<<"测试:"<<endl;
Mat test_image = imread("test.png",CV_LOAD_IMAGE_GRAYSCALE);
Mat test_temp;
resize(test_image,test_temp,Size(image_cols,image_rows),(0,0),(0,0),CV_INTER_AREA);//使用象素关系重采样。当图像缩小时候,该方法可以避免波纹出现
threshold(test_temp,test_temp,0,255,CV_THRESH_BINARY|CV_THRESH_OTSU);
Mat_<float>sampleMat(1,image_rows*image_cols);
for(int i = 0; i<image_rows*image_cols; ++i)
{
sampleMat.at<float>(0,i) = (float)test_temp.at<uchar>(i/8,i%8);
}
Mat responseMat;
bp.predict(sampleMat,responseMat);
Point maxLoc;
double maxVal = 0;
minMaxLoc(responseMat,NULL,&maxVal,NULL,&maxLoc);
cout<<"识别结果:"<<maxLoc.x<<" 相似度:"<<maxVal*100<<"%"<<endl;
imshow("test_image",test_image);
waitKey(0);
return 0;
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
测试结果:
代码已打包上传:
http://download.csdn.net/detail/qq_15947787/9518680
5/13 优化trainingDataMat, labelsMat
转载自:https://blog.csdn.net/qq_15947787/article/details/51385861