TensorFlow、Numpy中的axis的理解
TensorFlow中有很多函数涉及到axis,比如tf.reduce_mean(),其函数原型如下:
def reduce_mean(input_tensor,
axis=None,
keepdims=None,
name=None,
reduction_indices=None,
keep_dims=None):
其中axis表示的是,对该维度进行求均值(默认情况下,是对所有值求均值)。
除了TensorFlow中,numpy中也经常遇到很多对矩阵操作的函数会涉及axis操作。比如np.mean(),其函数原型如下:
def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue):
想要弄清楚如何处理涉及axis(维度)的操作,必须先明白axis是什么。
首先axis是维度,如果axis=0则对应着高; 如果axis=1则对应着行处理;如果axis=2则对应着列;如果axis=3…n(无法用直观的图来表示)。我相信很多人看到这还是会一头雾水。什么是高,行还有列。为了说明这个问题,我举个列子:
data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]
data_np=np.array(data)
print(data_np)
[[[ 1 2 3]
[ 11 22 33]]
[[ 4 5 6]
[ 44 55 66]]
[[ 10 11 12]
[100 110 120]]
[[ 7 8 9]
[ 77 88 99]]]
如上面,可以将最外层[ ]去掉,可以发现有4组元素(这里的元素是矩阵),你可以将其理解为高。
再从这4组元素中选取一组,比如选择的是
[[ 1 2 3]
[ 11 22 33]]
然后将该组的最外层[ ]去掉,可以发现有2组元素分别为[ 1 2 3]和 [ 11 22 33],此时对应的是行。
在从这两组元素中选一组,比如选择的是
[ 11 22 33]
现在无需去掉最外层的[ ]了,一眼就能看出里面有3个元素。这就是对应的列。
理解了上面的分析后,很容易就知道(高,行,列)对应的其实就是改矩阵的shape.
print(data_np.shape):
(4,2,3)
现在弄清楚了axis的值与(高,行,列)的关系后,再来分析tf.reduce_mean()或者np.mean()等函数是如何对axis进行操作的。
data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]
data_tensor=tf.constant(data,dtype=tf.float32)
mean_axis0=tf.reduce_mean(data_tensor,axis=0)
mean_axis1=tf.reduce_mean(data_tensor,axis=1)
mean_axis2=tf.reduce_mean(data_tensor,axis=2)
with tf.Session() as sess:
print(sess.run(mean_axis0))
print(sess.run(mean_axis1))
print(sess.run(mean_axis2))
针对上述代码,我们先对axis=0维度的数据处理进行分析。
首先对上述data数据进行立体化变换,如下图(本人本想用软件来绘制3D的矩阵叠加效果,可惜找了很多软件都不适合,也许是本人寻找的还不够,欢迎有知道可以绘制3D的矩阵叠加效果的朋友们,能够分享一下。感激…)
如上如,axis=0的维度数据求均值,
[[(1+4+10+7)/4 (2+5+11+8)/4 (3+6+12+9)/4]
[(11+44+100+77)/4 (22+55+110+88)/4 (33+66+120+99)/4]]
=
[[ 5.5 6.5 7.5 ]
[58. 68.75 79.5 ]]
同理,对axis=1的维度数据求均值
[[(1+11)/2 (2+22)/2 (3+33)/2]
[(4+44)/2 (5+55)/2 (6+66)/2]
[(10+100)/2 (11+110)/2 (12+120)/2]
[(7+77)/2 (8+88)/2 (9+99)/2]]
=
[[ 6. 12. 18. ]
[24. 30. 36. ]
[55. 60.5 66. ]
[42. 48. 54. ]]
同理可得axis=2维度的数据平均值为(过程留给读者去推,运算结果如下):
[[ 2. 22.]
[ 5. 55.]
[ 11. 110.]
[ 8. 88.]]
在python的世界里,有很多时候都需要对数据进行维度的操作,如果对axis理解的不透的话,很容易找不着方向。
更多人工智能技术干货请关注: