Batch Normalization的TensorFlow实现
https://www.toutiao.com/a6694908705214824963/
tf.nn.moments函数
函数定义如下:
def moments(x, axes, name=None, keep_dims=False)
1.函数的输入
x: 输入数据,格式一般为:[batchsize, height, width, kernels]
axes: List,在哪个维度上计算,比如:[0, 1, 2]
name: 操作的名称
keep_dims: 是否保持维度
2.函数的输出
mean: 均值
variance: 方差
3.使用举例
img = tf.Variable(tf.random_normal([128, 32, 32, 64]))
axis = list(range(len(img.get_shape()) - 1))
mean, variance = tf.nn.moments(img, axis)
tf.nn.batch_normalization函数
函数定义如下:
def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
在使用batch_normalization的时候,需要去除网络中的bias。
1.函数的输入
x: 输入的Tensor数据
mean: Tensor的均值
variance: Tensor的方差
offset: offset Tensor, 一般初始化为0,可训练
scale: scale Tensor,一般初始化为1,可训练
variance_epsilon: 一个小的浮点数,避免除数为0,一般取值0.001
name: 操作的名称
2.算法原理
李宏毅深度学习笔记