如何设置学习率


1.三角法(一)

如何设置学习率
import matplotlib.pyplot as plt


def get_learning_rate(base_lr, step, epoch):
    """Move the learning rate one step along a triangular schedule.

    Called once per batch, this climbs during even-numbered epochs and
    descends during odd-numbered ones, tracing a triangle wave.

    Args:
        base_lr: the learning rate before this step (the running value, not
            a fixed lower bound of the schedule).
        step: how far to move the rate on this call.
        epoch: current epoch index (assumed to be an int); when
            ``epoch % 2 == 1`` the rate descends, otherwise it climbs.

    Returns:
        ``base_lr + step`` on climbing epochs, ``base_lr - step`` on
        descending ones, computed as ``base_lr + (+/-1.0) * step``.
    """
    descending = epoch % 2 == 1
    direction = -1.0 if descending else 1.0
    return base_lr + direction * step


if __name__ == '__main__':
    # Plot a triangular (cyclical) learning-rate schedule: during even epochs
    # the rate climbs by `step` per batch until it reaches max_lr, during odd
    # epochs it falls back to 0 by the same amount, and the pattern repeats.
    max_lr = 1.0
    learning_rate = 0.0
    epochs = 10
    batches = 50
    # Divide by `batches` so that one epoch of steps spans exactly 0 -> max_lr
    # (up to floating-point rounding, since the rate is accumulated by
    # repeated addition). A larger divisor would stop each ramp short of
    # max_lr.
    step = max_lr / batches
    history_lr = []
    for epoch in range(epochs):
        for _ in range(batches):
            learning_rate = get_learning_rate(learning_rate, step, epoch)
            history_lr.append(learning_rate)
    # One x-value per recorded learning rate.
    iterations = list(range(len(history_lr)))
    plt.xlabel('iterations')
    plt.ylabel('learning_rate', rotation=90)
    plt.plot(iterations, history_lr)
    plt.show()

2.三角法(二)

如何设置学习率

import matplotlib.pyplot as plt


def get_learning_rate(base_lr, step, epoch):
    """Move the learning rate one step along a triangular schedule.

    Called once per batch, this climbs during even-numbered epochs and
    descends during odd-numbered ones, tracing a triangle wave.

    Args:
        base_lr: the learning rate before this step (the running value, not
            a fixed lower bound of the schedule).
        step: how far to move the rate on this call.
        epoch: current epoch index (assumed to be an int); when
            ``epoch % 2 == 1`` the rate descends, otherwise it climbs.

    Returns:
        ``base_lr + step`` on climbing epochs, ``base_lr - step`` on
        descending ones, computed as ``base_lr + (+/-1.0) * step``.
    """
    descending = epoch % 2 == 1
    direction = -1.0 if descending else 1.0
    return base_lr + direction * step


if __name__ == '__main__':
    # Plot a triangular learning-rate schedule with a decaying peak: the rate
    # climbs from 0 to max_lr during an even epoch, falls back to 0 during the
    # following odd epoch, and max_lr is then halved, so every triangle is
    # half as tall as the previous one.
    max_lr = 1.0
    learning_rate = 0.0
    epochs = 10
    batches = 50
    history_lr = []
    for epoch in range(epochs):
        # Divide by `batches` so one epoch of steps spans exactly 0 -> max_lr
        # (up to floating-point rounding from repeated addition). Recomputed
        # each epoch because max_lr shrinks below.
        step = max_lr / batches
        for _ in range(batches):
            learning_rate = get_learning_rate(learning_rate, step, epoch)
            history_lr.append(learning_rate)
        # An odd epoch completes one full up/down cycle (the rate is back at
        # ~0), so halve the peak for the next cycle.
        if epoch % 2 == 1:
            max_lr /= 2
    # One x-value per recorded learning rate.
    iterations = list(range(len(history_lr)))
    plt.xlabel('iterations')
    plt.ylabel('learning_rate', rotation=90)
    plt.plot(iterations, history_lr)
    plt.show()


可以看到，第二种方法与第一种的区别在于：每完成一个完整的升降周期（即每两个 epoch）后，都会将学习率的最大值 `max_lr` 减半（即第二段代码中 `if epoch % 2 == 1:` 与 `max_lr /= 2` 这两行），因此每个三角形的峰值依次减半。

如何设置学习率