opencv em算法,GMM模型

 

本章内容


      1. 获取颜色特征集合
      2. GMM聚类
      3. 显示聚类结果

opencv em算法,GMM模型

输出结果:

opencv em算法,GMM模型

源码

//#include <QCoreApplication>
#include <iostream>
#include <opencv2/opencv.hpp>
#include <opencv2/xfeatures2d.hpp>

int main(int argc, char *argv[])
{
    /* 本章内容
      1. 获取颜色特征集合
      2. GMM聚类
      3. 显示聚类结果
    */
    std::string fileName = "/home/wang/Image/toux.jpg";
    cv::Mat src = cv::imread(fileName);
    if(src.data == NULL){
        std::cout << "图片打开失败" << std::endl;
        return -1;
    }
    cv::imshow("src", src);
    int sampleNum = src.rows*src.cols;
    cv::Mat points(sampleNum,src.channels(),CV_32F,cv::Scalar(0));
    std::cout << "数据集维度:" << points.size() << std::endl;

    //1. rgb为特征集合
    for(int i=0;i<src.rows;i++){
        for(int j=0;j<src.cols;j++){
            cv::Vec3b bgr = src.at<cv::Vec3b>(i,j);
            int index =  i*src.cols + j;
            points.at<float>(index,0) = bgr[0];
            points.at<float>(index,1) = bgr[1];
            points.at<float>(index,2) = bgr[2];
        }
    }


    /* 2. GMM均值聚类
    api接口:
  */
    int clusterNum = 5;
    cv::Ptr<cv::ml::EM> EMModel = cv::ml::EM::create();
    EMModel->setClustersNumber(clusterNum); // 设置聚类数目
    EMModel->setCovarianceMatrixType(cv::ml::EM::COV_MAT_SPHERICAL); //设置协方差类型
    EMModel->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS + cv::TermCriteria::COUNT,10,0.1));
    cv::Mat labels;
    /*
     * api接口:CV_WRAP virtual bool trainEM(InputArray samples,
                         OutputArray logLikelihoods=noArray(),
                         OutputArray labels=noArray(),
                         OutputArray probs=noArray()) = 0;
     * 参数分析:
        samples: 输入的样本,一个单通道的矩阵。从这个样本中,进行高斯混和模型估计。
        logLikelihoods: 可选项,输出一个矩阵,里面包含每个样本的似然对数值。
        labels: 可选项,输出每个样本对应的标注。
        probs: 可选项,输出一个矩阵,里面包含每个隐性变量的后验概率
    */
    EMModel->trainEM(points,cv::noArray(),labels,cv::noArray());
    // 3.显示图像分割结果
    std::vector<cv::Vec3b> colorSet;
    cv::RNG rng(12345);
    for(int i=0; i < clusterNum; i++) colorSet.push_back(cv::Vec3b(rng.uniform(0,255),
          rng.uniform(0,255),rng.uniform(0,255)));
    cv::Mat result = cv::Mat::zeros(src.rows,src.cols,CV_8UC3);
    // RGB 数据转换到样本数据
    for(int i=0;i<src.rows;i++){
        for(int j=0;j<src.cols;j++){
            int index = i*src.cols + j;
            int label = labels.at<int>(index);
            result.at<cv::Vec3b>(i, j) = colorSet[label];
      }
    }

    cv::imshow("cluster result", result);

    // 绘制中心

    cv::waitKey(0);
    return 1;
}