Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

Gradient Descent

在機器學習的過程中,常需要將 Cost Function 的值減小,通常用 Gradient Descent 來做最佳化的方法來達成。但是用 Gradient Descent 有其缺點,例如,很容易卡在 Local Minimum。

Gradient Descent 的公式如下:

關於Gradient Descent的公式解說,請參考:Optimization Method -- Gradient Descent & AdaGrad

Getting Stuck in Local Minimum

舉個例子,如果 Cost Function 為  ,有 Local Minimum  ,畫出來的圖形如下:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

當執行 Gradient Descent 的時候,則會卡在 Local Minimum,如下圖:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

解決卡在 Local Minimum 的方法,可加入 Momentum ,使它在 Gradient 等於零的時候,還可繼續前進。

Gradient Descent with Momentum

Momentum 的概念如下: 當一顆球從斜坡上滾到平地時,球在平地仍會持續滾動,因為球具有動量,也就是說,它的速度跟上一個時間點的速度有關。

模擬 Momentum的方式很簡單,即是把上一個時間點用 Gradient 得出的變化量也考慮進去。

Gradient Descent with Momentum 的公式如下:

其中  為  時間點,修正  值所用的變化量,而  則是  時間點的修正量,而  則是用來控制在  時間點中的  具有上個時間點的  值的比例。 好比說,在  時間點時,球的速度會跟  時間點有關。 而  ,則是  時間點算出之 Gradient  乘上 Learning Rate  後,在  中所占的比例。

舉前述例子,若起始參數為  ,畫出目標函數,藍點為起始點  的位置:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

用 Gradient Descent with Momentum 來更新  的值,如下:

化減後得:

設初始化值  ,參數  ,代入  ,則:

更新圖上的藍點,如下圖:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

再往下走一步,  的值如下:

更新圖上的藍點,如下圖:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

在以上兩步中,可發現  的值逐漸變大。由於一開始  都是零,它會跟前一個時間點的值有關,所以看起來就好像是球從斜坡上滾下來時,慢慢加速,而在球經過 Local Minimum時,也會慢慢減速,不會直接卡在 Local Minimum 。整個過程如下圖:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

動畫版:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

Implementation

再來進入實作的部分
首先,開啟新的檔案 momentum.py 並貼上以下程式碼:

momentum.py
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.pyplot as plt
import numpy as np

def func(x,y):
  return (0.3*y**3+y**2+0.3*x**3+x**2)

def func_grad(x,y):
  return (0.9*x**2+2*x, 0.9*y**2+2*y )

def plot_func(xt,yt,c='r'):
  fig = plt.figure()
  ax = fig.gca(projection='3d',
        elev=7., azim=-175)
  X, Y = np.meshgrid(np.arange(-5, 5, 0.25), np.arange(-5, 5, 0.25))
  Z = func(X,Y) 
  surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, 
    cmap=cm.coolwarm, linewidth=0.1, alpha=0.3)
  ax.set_zlim(-20, 100)
  ax.scatter(xt, yt, func(xt,yt),c=c, marker='o' )
  ax.set_title("x=%.5f, y=%.5f, f(x,y)=%.5f"%(xt,yt,func(xt,yt))) 
  plt.show()
  plt.close()

def run_grad():
  xt = 3 
  yt = 3 
  eta = 0.1
  plot_func(xt,yt,'r')
  for i in range(20):
    gxt, gyt = func_grad(xt,yt)
    xt = xt - eta * gxt
    yt = yt - eta * gyt
    if xt < -5 or yt < -5 or xt > 5 or yt > 5:
      break
    plot_func(xt,yt,'r')

def run_momentum():
  xt = 3 
  yt = 3 
  eta = 0.2
  beta = 0.9
  plot_func(xt,yt,'b')
  delta_x = 0
  delta_y = 0
  for i in range(20):
    gxt, gyt = func_grad(xt,yt)
    delta_x = beta * delta_x + (1-beta)*eta*gxt
    delta_y = beta * delta_y + (1-beta)*eta*gyt
    xt = xt - delta_x
    yt = yt - delta_y 
    if xt < -5 or yt < -5 or xt > 5 or yt > 5:
      break
    plot_func(xt,yt,'b')

其中, func(x,y) 為目標函數,func_grad(x,y) 為目標函數的 gradient ,而 plot_func(xt,yt,c='r') 可畫出目標函數的曲面, run_grad() 用來執行 Gradient Descent , run_momentum() 用來執行 Gradient Descent with Momentum 。 xt 和 yt 對應到前例的  ,而 eta 為 Learning Rate 。 for i in range(20) 表示最多會跑20個迴圈,而 if xt < -5 or yt < -5 or xt > 5 or yt > 5 表示,如果 xt 和 yt 超出邊界,則會先結束迴圈。

到 python console 執行:

>>> import momentum

執行 Gradient Descent ,指令如下:

>>> momentum.run_grad()

則程式會逐一畫出整個過程:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

以此類推

執行 Gradient Descent with Momentum ,指令如下:

>>> momentum.run_momentum()

則程式會逐一畫出整個過程:

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

Optimization Method -- Gradient Descent with Momentum Gradient Descent (转)

以此類推

Reference

Visualizing Optimization Algos

http://imgur.com/a/Hqolp