np.expand_dims的用法

np.expand_dims(a,axis=?)

即扩展维度,np.expand_dims(a,axis=)即在 a 的相应的axis轴上扩展维度
a = np.array([[1,2],[3,5]])

y = np.expand_dims(a, axis=2)
z = np.expand_dims(a, axis=1)
print(a.shape)
print(y.shape)
print(z.shape)
输出
(2, 2)
(2, 2, 1)
(2, 1, 2)

y 变成了 [

[ [1], [2] ],

 [ [3], [5] ]

]

 

z 变成了[   

     [   [1,2],[3,5]   ]     

              ]

np.expand_dims的用法