C++最小二乘法拟合-(线性拟合和多项式拟合)

在进行曲线拟合时用的最多的是最小二乘法,其中以一元函数(线性)和多元函数(多项式)居多,下面这个类专门用于进行多项式拟合,可以根据用户输入的阶次进行多项式拟合,算法来自于网上,和GSL的拟合算法对比过,没有问题。此类在拟合完后还能计算拟合之后的误差:SSE(剩余平方和),SSR(回归平方和),RMSE(均方根误差),R-square(确定系数)。

 

1.fit类的实现

先看看fit类的代码:(只有一个头文件方便使用)

 

 
  1. #ifndef CZY_MATH_FIT

  2. #define CZY_MATH_FIT

  3. #include <vector>

  4. /*

  5. 尘中远,于2014.03.20

  6. 主页:http://blog.csdn.net/czyt1988/article/details/21743595

  7. 参考:http://blog.csdn.net/maozefa/article/details/1725535

  8. */

  9. namespace czy{

  10. ///

  11. /// \brief 曲线拟合类

  12. ///

  13. class Fit{

  14. std::vector<double> factor; ///<拟合后的方程系数

  15. double ssr; ///<回归平方和

  16. double sse; ///<(剩余平方和)

  17. double rmse; ///<RMSE均方根误差

  18. std::vector<double> fitedYs;///<存放拟合后的y值,在拟合时可设置为不保存节省内存

  19. public:

  20. Fit():ssr(0),sse(0),rmse(0){factor.resize(2,0);}

  21. ~Fit(){}

  22. ///

  23. /// \brief 直线拟合-一元回归,拟合的结果可以使用getFactor获取,或者使用getSlope获取斜率,getIntercept获取截距

  24. /// \param x 观察值的x

  25. /// \param y 观察值的y

  26. /// \param isSaveFitYs 拟合后的数据是否保存,默认否

  27. ///

  28. template<typename T>

  29. bool linearFit(const std::vector<typename T>& x, const std::vector<typename T>& y,bool isSaveFitYs=false)

  30. {

  31. return linearFit(&x[0],&y[0],getSeriesLength(x,y),isSaveFitYs);

  32. }

  33. template<typename T>

  34. bool linearFit(const T* x, const T* y,size_t length,bool isSaveFitYs=false)

  35. {

  36. factor.resize(2,0);

  37. typename T t1=0, t2=0, t3=0, t4=0;

  38. for(int i=0; i<length; ++i)

  39. {

  40. t1 += x[i]*x[i];

  41. t2 += x[i];

  42. t3 += x[i]*y[i];

  43. t4 += y[i];

  44. }

  45. factor[1] = (t3*length - t2*t4) / (t1*length - t2*t2);

  46. factor[0] = (t1*t4 - t2*t3) / (t1*length - t2*t2);

  47. //////////////////////////////////////////////////////////////////////////

  48. //计算误差

  49. calcError(x,y,length,this->ssr,this->sse,this->rmse,isSaveFitYs);

  50. return true;

  51. }

  52. ///

  53. /// \brief 多项式拟合,拟合y=a0+a1*x+a2*x^2+……+apoly_n*x^poly_n

  54. /// \param x 观察值的x

  55. /// \param y 观察值的y

  56. /// \param poly_n 期望拟合的阶数,若poly_n=2,则y=a0+a1*x+a2*x^2

  57. /// \param isSaveFitYs 拟合后的数据是否保存,默认是

  58. ///

  59. template<typename T>

  60. void polyfit(const std::vector<typename T>& x

  61. ,const std::vector<typename T>& y

  62. ,int poly_n

  63. ,bool isSaveFitYs=true)

  64. {

  65. polyfit(&x[0],&y[0],getSeriesLength(x,y),poly_n,isSaveFitYs);

  66. }

  67. template<typename T>

  68. void polyfit(const T* x,const T* y,size_t length,int poly_n,bool isSaveFitYs=true)

  69. {

  70. factor.resize(poly_n+1,0);

  71. int i,j;

  72. //double *tempx,*tempy,*sumxx,*sumxy,*ata;

  73. std::vector<double> tempx(length,1.0);

  74.  
  75. std::vector<double> tempy(y,y+length);

  76.  
  77. std::vector<double> sumxx(poly_n*2+1);

  78. std::vector<double> ata((poly_n+1)*(poly_n+1));

  79. std::vector<double> sumxy(poly_n+1);

  80. for (i=0;i<2*poly_n+1;i++){

  81. for (sumxx[i]=0,j=0;j<length;j++)

  82. {

  83. sumxx[i]+=tempx[j];

  84. tempx[j]*=x[j];

  85. }

  86. }

  87. for (i=0;i<poly_n+1;i++){

  88. for (sumxy[i]=0,j=0;j<length;j++)

  89. {

  90. sumxy[i]+=tempy[j];

  91. tempy[j]*=x[j];

  92. }

  93. }

  94. for (i=0;i<poly_n+1;i++)

  95. for (j=0;j<poly_n+1;j++)

  96. ata[i*(poly_n+1)+j]=sumxx[i+j];

  97. gauss_solve(poly_n+1,ata,factor,sumxy);

  98. //计算拟合后的数据并计算误差

  99. fitedYs.reserve(length);

  100. calcError(&x[0],&y[0],length,this->ssr,this->sse,this->rmse,isSaveFitYs);

  101.  
  102. }

  103. ///

  104. /// \brief 获取系数

  105. /// \param 存放系数的数组

  106. ///

  107. void getFactor(std::vector<double>& factor){factor = this->factor;}

  108. ///

  109. /// \brief 获取拟合方程对应的y值,前提是拟合时设置isSaveFitYs为true

  110. ///

  111. void getFitedYs(std::vector<double>& fitedYs){fitedYs = this->fitedYs;}

  112.  
  113. ///

  114. /// \brief 根据x获取拟合方程的y值

  115. /// \return 返回x对应的y值

  116. ///

  117. template<typename T>

  118. double getY(const T x) const

  119. {

  120. double ans(0);

  121. for (size_t i=0;i<factor.size();++i)

  122. {

  123. ans += factor[i]*pow((double)x,(int)i);

  124. }

  125. return ans;

  126. }

  127. ///

  128. /// \brief 获取斜率

  129. /// \return 斜率值

  130. ///

  131. double getSlope(){return factor[1];}

  132. ///

  133. /// \brief 获取截距

  134. /// \return 截距值

  135. ///

  136. double getIntercept(){return factor[0];}

  137. ///

  138. /// \brief 剩余平方和

  139. /// \return 剩余平方和

  140. ///

  141. double getSSE(){return sse;}

  142. ///

  143. /// \brief 回归平方和

  144. /// \return 回归平方和

  145. ///

  146. double getSSR(){return ssr;}

  147. ///

  148. /// \brief 均方根误差

  149. /// \return 均方根误差

  150. ///

  151. double getRMSE(){return rmse;}

  152. ///

  153. /// \brief 确定系数,系数是0~1之间的数,是数理上判定拟合优度的一个量

  154. /// \return 确定系数

  155. ///

  156. double getR_square(){return 1-(sse/(ssr+sse));}

  157. ///

  158. /// \brief 获取两个vector的安全size

  159. /// \return 最小的一个长度

  160. ///

  161. template<typename T>

  162. size_t getSeriesLength(const std::vector<typename T>& x

  163. ,const std::vector<typename T>& y)

  164. {

  165. return (x.size() > y.size() ? y.size() : x.size());

  166. }

  167. ///

  168. /// \brief 计算均值

  169. /// \return 均值

  170. ///

  171. template <typename T>

  172. static T Mean(const std::vector<T>& v)

  173. {

  174. return Mean(&v[0],v.size());

  175. }

  176. template <typename T>

  177. static T Mean(const T* v,size_t length)

  178. {

  179. T total(0);

  180. for (size_t i=0;i<length;++i)

  181. {

  182. total += v[i];

  183. }

  184. return (total / length);

  185. }

  186. ///

  187. /// \brief 获取拟合方程系数的个数

  188. /// \return 拟合方程系数的个数

  189. ///

  190. size_t getFactorSize(){return factor.size();}

  191. ///

  192. /// \brief 根据阶次获取拟合方程的系数,

  193. /// 如getFactor(2),就是获取y=a0+a1*x+a2*x^2+……+apoly_n*x^poly_n中a2的值

  194. /// \return 拟合方程的系数

  195. ///

  196. double getFactor(size_t i){return factor.at(i);}

  197. private:

  198. template<typename T>

  199. void calcError(const T* x

  200. ,const T* y

  201. ,size_t length

  202. ,double& r_ssr

  203. ,double& r_sse

  204. ,double& r_rmse

  205. ,bool isSaveFitYs=true

  206. )

  207. {

  208. T mean_y = Mean<T>(y,length);

  209. T yi(0);

  210. fitedYs.reserve(length);

  211. for (int i=0; i<length; ++i)

  212. {

  213. yi = getY(x[i]);

  214. r_ssr += ((yi-mean_y)*(yi-mean_y));//计算回归平方和

  215. r_sse += ((yi-y[i])*(yi-y[i]));//残差平方和

  216. if (isSaveFitYs)

  217. {

  218. fitedYs.push_back(double(yi));

  219. }

  220. }

  221. r_rmse = sqrt(r_sse/(double(length)));

  222. }

  223. template<typename T>

  224. void gauss_solve(int n

  225. ,std::vector<typename T>& A

  226. ,std::vector<typename T>& x

  227. ,std::vector<typename T>& b)

  228. {

  229. gauss_solve(n,&A[0],&x[0],&b[0]);

  230. }

  231. template<typename T>

  232. void gauss_solve(int n

  233. ,T* A

  234. ,T* x

  235. ,T* b)

  236. {

  237. int i,j,k,r;

  238. double max;

  239. for (k=0;k<n-1;k++)

  240. {

  241. max=fabs(A[k*n+k]); /*find maxmum*/

  242. r=k;

  243. for (i=k+1;i<n-1;i++){

  244. if (max<fabs(A[i*n+i]))

  245. {

  246. max=fabs(A[i*n+i]);

  247. r=i;

  248. }

  249. }

  250. if (r!=k){

  251. for (i=0;i<n;i++) /*change array:A[k]&A[r] */

  252. {

  253. max=A[k*n+i];

  254. A[k*n+i]=A[r*n+i];

  255. A[r*n+i]=max;

  256. }

  257. }

  258. max=b[k]; /*change array:b[k]&b[r] */

  259. b[k]=b[r];

  260. b[r]=max;

  261. for (i=k+1;i<n;i++)

  262. {

  263. for (j=k+1;j<n;j++)

  264. A[i*n+j]-=A[i*n+k]*A[k*n+j]/A[k*n+k];

  265. b[i]-=A[i*n+k]*b[k]/A[k*n+k];

  266. }

  267. }

  268.  
  269. for (i=n-1;i>=0;x[i]/=A[i*n+i],i--)

  270. for (j=i+1,x[i]=b[i];j<n;j++)

  271. x[i]-=A[i*n+j]*x[j];

  272. }

  273. };

  274. }

  275.  
  276.  
  277. #endif


 

 

为了防止重命名,把其放置于czy的命名空间中,此类主要两个函数:

1.求解线性拟合:

 

 
  1. ///

  2. /// \brief 直线拟合-一元回归,拟合的结果可以使用getFactor获取,或者使用getSlope获取斜率,getIntercept获取截距

  3. /// \param x 观察值的x

  4. /// \param y 观察值的y

  5. /// \param length x,y数组的长度

  6. /// \param isSaveFitYs 拟合后的数据是否保存,默认否

  7. ///

  8. template<typename T>

  9. bool linearFit(const std::vector<typename T>& x, const std::vector<typename T>& y,bool isSaveFitYs=false);

  10. template<typename T>

  11. bool linearFit(const T* x, const T* y,size_t length,bool isSaveFitYs=false);


 

 

2.多项式拟合:

 

 

 
  1. ///

  2. /// \brief 多项式拟合,拟合y=a0+a1*x+a2*x^2+……+apoly_n*x^poly_n

  3. /// \param x 观察值的x

  4. /// \param y 观察值的y

  5. /// \param length x,y数组的长度

  6. /// \param poly_n 期望拟合的阶数,若poly_n=2,则y=a0+a1*x+a2*x^2

  7. /// \param isSaveFitYs 拟合后的数据是否保存,默认是

  8. ///

  9. template<typename T>

  10. void polyfit(const std::vector<typename T>& x,const std::vector<typename T>& y,int poly_n,bool isSaveFitYs=true);

  11. template<typename T>

  12. void polyfit(const T* x,const T* y,size_t length,int poly_n,bool isSaveFitYs=true);


 

 

这两个函数都用模板函数形式写,主要是为了能使用于float和double两种数据类型

 

 

2.fit类的MFC示范程序

下面看看如何使用这个类,以MFC示范,使用了开源的绘图控件Hight-Speed Charting,使用方法见 http://blog.csdn.net/czyt1988/article/details/8740500

 

新建对话框文件,

对话框资源文件如图所示:

C++最小二乘法拟合-(线性拟合和多项式拟合)

加入下面的这些变量:

 

 
  1. std::vector<double> m_x,m_y,m_yploy;

  2. const size_t m_size;

  3. CChartLineSerie *m_pLineSerie1;

  4. CChartLineSerie *m_pLineSerie2;


由于m_size是常量,因此需要在构造函数进行初始化,如:

 

 

 
  1. ClineFitDlg::ClineFitDlg(CWnd* pParent /*=NULL*/)

  2. : CDialogEx(ClineFitDlg::IDD, pParent)

  3. ,m_size(512)

  4. ,m_pLineSerie1(NULL)

 

 

初始化两条曲线:

 

 
  1. CChartAxis *pAxis = NULL;

  2. pAxis = m_chartCtrl.CreateStandardAxis(CChartCtrl::BottomAxis);

  3. pAxis->SetAutomatic(true);

  4. pAxis = m_chartCtrl.CreateStandardAxis(CChartCtrl::LeftAxis);

  5. pAxis->SetAutomatic(true);

  6. m_x.resize(m_size);

  7. m_y.resize(m_size);

  8. m_yploy.resize(m_size);

  9. for(size_t i =0;i<m_size;++i)

  10. {

  11. m_x[i] = i;

  12. m_y[i] = i+randf(-25,28);

  13. m_yploy[i] = 0.005*pow(double(i),2)+0.0012*i+4+randf(-25,25);

  14. }

  15. m_chartCtrl.RemoveAllSeries();//先清空

  16. m_pLineSerie1 = m_chartCtrl.CreateLineSerie();

  17. m_pLineSerie1->SetSeriesOrdering(poNoOrdering);//设置为无序

  18. m_pLineSerie1->AddPoints(&m_x[0], &m_y[0], m_size);

  19. m_pLineSerie1->SetName(_T("线性数据"));

  20. m_pLineSerie2 = m_chartCtrl.CreateLineSerie();

  21. m_pLineSerie2->SetSeriesOrdering(poNoOrdering);//设置为无序

  22. m_pLineSerie2->AddPoints(&m_x[0], &m_yploy[0], m_size);

  23. m_pLineSerie2->SetName(_T("多项式数据"));

 

 

rangf是随机数生成函数,实现如下:

 

 
  1. double ClineFitDlg::randf(double min,double max)

  2. {

  3. int minInteger = (int)(min*10000);

  4. int maxInteger = (int)(max*10000);

  5. int randInteger = rand()*rand();

  6. int diffInteger = maxInteger - minInteger;

  7. int resultInteger = randInteger % diffInteger + minInteger;

  8. return resultInteger/10000.0;

  9. }


运行程序,如图所示

 

C++最小二乘法拟合-(线性拟合和多项式拟合)

线性拟合的使用如下:

 

 
  1. void ClineFitDlg::OnBnClickedButton1()

  2. {

  3. CString str,strTemp;

  4. czy::Fit fit;

  5. fit.linearFit(m_x,m_y);

  6. str.Format(_T("方程:y=%gx+%g\r\n误差:ssr:%g,sse=%g,rmse:%g,确定系数:%g"),fit.getSlope(),fit.getIntercept()

  7. ,fit.getSSR(),fit.getSSE(),fit.getRMSE(),fit.getR_square());

  8. GetDlgItemText(IDC_EDIT,strTemp);

  9. SetDlgItemText(IDC_EDIT,strTemp+_T("\r\n------------------------\r\n")+str);

  10. //在图上绘制拟合的曲线

  11. CChartLineSerie* pfitLineSerie1 = m_chartCtrl.CreateLineSerie();

  12. std::vector<double> x(2,0),y(2,0);

  13. x[0] = 0;x[1] = m_size-1;

  14. y[0] = fit.getY(x[0]);y[1] = fit.getY(x[1]);

  15. pfitLineSerie1->SetSeriesOrdering(poNoOrdering);//设置为无序

  16. pfitLineSerie1->AddPoints(&x[0], &y[0], 2);

  17. pfitLineSerie1->SetName(_T("拟合方程"));//SetName的作用将在后面讲到

  18. pfitLineSerie1->SetWidth(2);

  19. }


需要如下步骤:

 

 

  • 声明Fit类,用于头文件在czy命名空间中,因此需要显示声明命名空间名称czy::Fit fit;
  • 把观察数据输入进行拟合,由于是线性拟合,可以使用LinearFit函数,此函数把观察量的x值和y值传入即可进行拟合
  • 拟合完后,拟合的相关结果保存在czy::Fit里面,可以通过相关方法调用,方法在头文件中都有详细说明

 

 

运行结果如图所示:

C++最小二乘法拟合-(线性拟合和多项式拟合)

 

多项式拟合的使用如下:

 

 
  1. void ClineFitDlg::OnBnClickedButton2()

  2. {

  3. CString str;

  4. GetDlgItemText(IDC_EDIT1,str);

  5. if (str.IsEmpty())

  6. {

  7. MessageBox(_T("请输入阶次"),_T("警告"));

  8. return;

  9. }

  10. int n = _ttoi(str);

  11. if (n<0)

  12. {

  13. MessageBox(_T("请输入大于1的阶数"),_T("警告"));

  14. return;

  15. }

  16. czy::Fit fit;

  17. fit.polyfit(m_x,m_yploy,n,true);

  18. CString strFun(_T("y=")),strTemp(_T(""));

  19. for (int i=0;i<fit.getFactorSize();++i)

  20. {

  21. if (0 == i)

  22. {

  23. strTemp.Format(_T("%g"),fit.getFactor(i));

  24. }

  25. else

  26. {

  27. double fac = fit.getFactor(i);

  28. if (fac<0)

  29. {

  30. strTemp.Format(_T("%gx^%d"),fac,i);

  31. }

  32. else

  33. {

  34. strTemp.Format(_T("+%gx^%d"),fac,i);

  35. }

  36. }

  37. strFun += strTemp;

  38. }

  39. str.Format(_T("方程:%s\r\n误差:ssr:%g,sse=%g,rmse:%g,确定系数:%g"),strFun

  40. ,fit.getSSR(),fit.getSSE(),fit.getRMSE(),fit.getR_square());

  41. GetDlgItemText(IDC_EDIT,strTemp);

  42. SetDlgItemText(IDC_EDIT,strTemp+_T("\r\n------------------------\r\n")+str);

  43. //绘制拟合后的多项式

  44. std::vector<double> yploy;

  45. fit.getFitedYs(yploy);

  46. CChartLineSerie* pfitLineSerie1 = m_chartCtrl.CreateLineSerie();

  47. pfitLineSerie1->SetSeriesOrdering(poNoOrdering);//设置为无序

  48. pfitLineSerie1->AddPoints(&m_x[0], &yploy[0], yploy.size());

  49. pfitLineSerie1->SetName(_T("多项式拟合方程"));//SetName的作用将在后面讲到

  50. pfitLineSerie1->SetWidth(2);

  51. }


步骤如下:

 

 

  • 和线性拟合一样,声明Fit变量
  • 输入观察值,同时输入需要拟合的阶次,这里输入2阶,就是2项式拟合,最后的布尔变量是标定是否需要把拟合的结果点保存起来,保存点会根据观察的x值计算拟合的y值,保存结果点会花费更多的内存,如果拟合后需要绘制,设为true会更方便,如果只需要拟合的方程,可以设置为false
  • 拟合完后,拟合的相关结果保存在czy::Fit里面,可以通过相关方法调用,方法在头文件中都有详细说明

代码:

 
  1. for (int i=0;i<fit.getFactorSize();++i)

  2. {

  3. if (0 == i)

  4. {

  5. strTemp.Format(_T("%g"),fit.getFactor(i));

  6. }

  7. else

  8. {

  9. double fac = fit.getFactor(i);

  10. if (fac<0)

  11. {

  12. strTemp.Format(_T("%gx^%d"),fac,i);

  13. }

  14. else

  15. {

  16. strTemp.Format(_T("+%gx^%d"),fac,i);

  17. }

  18. }

  19. strFun += strTemp;

  20. }


是用于生成方程的,由于系数小于时,打印时会把负号“-”显示,而正数时却不会显示正号,因此需要进行判断,如果小于0就不用添加“+”号,如果大于0就添加“+”号

结果如下:

C++最小二乘法拟合-(线性拟合和多项式拟合)

 

 

源代码下载:

C++最小二乘法拟合-(线性拟合和多项式拟合)

 

转载于:https://my.oschina.net/2nmjeSMen3/blog/674377