opencv+caffe 基于CNN的年龄性别检测实例demo(c++)
环境配置
Visual Studio 2017 配置 OpenCV
CNN 使用 Caffe 框架训练好的模型(age_net.caffemodel 和 gender_net.caffemodel)
原理讲解
基于2015 CVPR 的论文,《Age and Gender Classification using Convolutional Neural Networks》
参考博客:https://blog.****.net/qq_14845119/article/details/52454539
一些思考:
当时是作为一个小的demo展示。
关于论文老师问了我一些问题:
- 这篇文章里为什么用的是八分类,不是五分类,十分类?
- 该论文是在 Adience 数据集上测试的。如今光照、遮挡等限制问题已经在很大程度上得到解决,那么如何继续提升精度?我的回答是增大数据集或者对图片进行预处理,老师说我说得太宽泛了,更细致的我也答不上来 emmmmm
代码详解
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
using namespace cv;
using namespace cv::dnn;
using namespace std;
// Files the demo depends on: for each network, a weights file (.caffemodel)
// plus a network description file (.prototxt).
// Haar cascade used to detect faces.
String haar_file("G:/opencv/build/etc/haarcascades/haarcascade_frontalface_alt_tree.xml");
// Age network: weights and description.
String age_model("G:/VCprojects/DNN/age_gender/age_net.caffemodel");
String age_text("G:/VCprojects/DNN/age_gender/deploy_age.prototxt");
// Gender network: weights and description.
String gender_model("G:/VCprojects/DNN/age_gender/gender_net.caffemodel");
String gender_text("G:/VCprojects/DNN/age_gender/deploy_gender.prototxt");
// Forward declarations of the two prediction helpers. Each one takes a network
// and an image, runs the network on the image, and draws the result onto it.
void predict_age(Net &net, Mat &image);
void predict_gender(Net &net, Mat &image);
int main(int argc, char** argv) {
//使用.prototxt和.caffemodel文件的路径读取并初始化网络
Net age_net = readNetFromCaffe(age_text, age_model);
Net gender_net = readNetFromCaffe(gender_text, gender_model);
//检查网络是否已成功读取
if (age_net.empty())
{
std::cerr << "Can't load age network by using the following files: " << std::endl;
std::cerr << "prototxt: " << age_text << std::endl;
std::cerr << "caffemodel: " << age_model << std::endl;
std::cerr << "age_net.caffemodel can be downloaded here:" << std::endl;
exit(-1);
}
if (gender_net.empty())
{
std::cerr << "Can't load gender network by using the following files: " << std::endl;
std::cerr << "prototxt: " << gender_text << std::endl;
std::cerr << "caffemodel: " << gender_model << std::endl;
std::cerr << "gender_net.caffemodel can be downloaded here:" << std::endl;
exit(-1);
}
Mat src = imread("G:/VCprojects/DNN/age_gender/face00.png");
//Mat src = imread("G:/VCprojects/DNN/age_gender/three.png");
//Mat src = imread("G:/VCprojects/DNN/age_gender/9.jpg");
if (src.empty()) {
printf("could not load image...\n");
return -1;
}
namedWindow("input", CV_WINDOW_AUTOSIZE);
imshow("input", src);
//创建一个级联检测器
CascadeClassifier detector;
//load文件
detector.load(haar_file);
//vector存放人脸检测的结果
vector<Rect> faces;
Mat gray;
//变成一张灰度图像
cvtColor(src, gray, COLOR_BGR2GRAY);
//检测人脸
detector.detectMultiScale(gray, faces, 1.02, 1, 0, Size(40, 40), Size(200, 200));
//找到人脸的话
for (size_t t = 0; t < faces.size(); t++) {
//找到的人脸在他上面绘制一个矩形
rectangle(src, faces[t], Scalar(0, 255, 255), 2, 8, 0);
//预测年龄和性别
predict_age(age_net, src(faces[t]));
predict_gender(age_net, src(faces[t]));
}
imshow("age-gender-prediction-demo", src);
waitKey(0);
return 0;
}
// Returns the display labels of the 8 age classes, in class-index order.
vector<String> ageLabels() {
    // Label table; position i is the text shown for class index i.
    static const char* const kAgeLabels[] = {
        "0-2",
        "4 - 6",
        "8 - 13",
        "15 - 20",
        "25 - 32",
        "38 - 43",
        "48 - 53",
        "60-"
    };
    const size_t labelCount = sizeof(kAgeLabels) / sizeof(kAgeLabels[0]);
    return vector<String>(kAgeLabels, kAgeLabels + labelCount);
}
// Predicts the age class of `image` with `net` and reports it.
//
// The image is converted to a 227x227 blob, fed to the network input named
// "data", and the output of the layer named "prob" is read back. The label of
// the highest-scoring class is drawn onto `image` in red at (2, 10) and the
// class index, label and score (as a percentage) are printed to stdout.
//
// net   - network to run; its input is overwritten by this call.
// image - face image to classify; modified in place by the drawn text.
void predict_age(Net &net, Mat &image) {
    // Convert the image into a blob the network accepts and set it as input.
    Mat inputBlob = blobFromImage(image, 1.0, Size(227, 227));
    net.setInput(inputBlob, "data");

    // Forward pass, then flatten the result to a single-channel, single-row matrix.
    Mat scores = net.forward("prob").reshape(1, 1);

    // Find the maximum score and its position; the column is the class index.
    double bestScore;
    Point bestPos;
    minMaxLoc(scores, NULL, &bestScore, NULL, &bestPos);
    const int bestIdx = bestPos.x;

    // Translate the class index into its label (at() throws if out of range).
    const vector<String> labels = ageLabels();
    const String &bestLabel = labels.at(bestIdx);

    // Report the result on the image and on the console.
    putText(image, format("age:%s", bestLabel.c_str()), Point(2, 10),
            FONT_HERSHEY_PLAIN, 0.8, Scalar(0, 0, 255), 1);
    std::cout << "Best class at age: #" << bestIdx << " '" << bestLabel << "'" << std::endl;
    std::cout << "Probability: " << bestScore * 100 << "%" << std::endl;
}
// Predicts the gender of `image` with `net` and draws it onto the image.
//
// The image is converted to a 227x227 blob, fed to the network input named
// "data", and the output of the layer named "prob" is read back. The first two
// scores are compared: "M" is drawn when the first is strictly larger,
// otherwise "F". The text is drawn in red at (2, 20).
//
// net   - network to run; its input is overwritten by this call.
// image - face image to classify; modified in place by the drawn text.
void predict_gender(Net &net, Mat &image) {
    // Build the input blob and hand it to the network.
    Mat inputBlob = blobFromImage(image, 1.0, Size(227, 227));
    net.setInput(inputBlob, "data");

    // Forward pass, flattened to a single-channel, single-row matrix.
    Mat scores = net.forward("prob").reshape(1, 1);

    // Only two outcomes are considered: column 0 wins -> "M", otherwise "F".
    const float firstScore = scores.at<float>(0, 0);
    const float secondScore = scores.at<float>(0, 1);
    const char *genderTag = "F";
    if (firstScore > secondScore) {
        genderTag = "M";
    }

    putText(image, format("gender:%s", genderTag),
            Point(2, 20), FONT_HERSHEY_PLAIN, 0.8, Scalar(0, 0, 255), 1);
}