numpy数组中reshape和squeeze函数的使用
参考了:http://blog.****.net/zenghaitao0128/article/details/78512715,作了一些自己的补充。
numpy中的reshape函数和squeeze函数是深度学习代码编写中经常使用的函数,需要深入的理解。
其中,reshape函数用于调整数组的轴和维度,而squeeze函数的用法如下,
语法:numpy.squeeze(a,axis = None)
1)a表示输入的数组;
2)axis用于指定需要删除的维度,但是指定的维度必须为单维度,否则将会报错;
3)axis的取值可为None 或 int 或 tuple of ints, 可选。若axis为空,则删除所有单维度的条目;
4)返回值:数组
5) 不会修改原数组;
作用:从数组的形状中删除单维度条目,即把shape中为1的维度去掉
举例:
numpy的reshape和squeeze函数:
import numpy as np
e = np.arange(10)
print(e)
一维数组:[0 1 2 3 4 5 6 7 8 9]
f = e.reshape(1,1,10)
print(f)
三维数组:(第三个方括号里有十个元素)
[[[0 1 2 3 4 5 6 7 8 9]]],前两维的秩为1
g = f.reshape(1,10,1)
print(g)
三维数组:(第二个方括号里有十个元素)
[[[0] [1] [2] [3] [4] [5] [6] [7] [8] [9]]]
h = e.reshape(10,1,1)
print(h)
三维数组:(第一个方括号里有10个元素) [[[0]] [[1]] [[2]] [[3]] [[4]] [[5]] [[6]] [[7]] [[8]] [[9]]]
利用squeeze可以把数组中的1维度去掉(从0开始指定轴),以下为不加参数axis,去掉所有1维的轴:
m = np.squeeze(h)
print(m)
以下指定去掉第几轴
n = np.squeeze(h,2)
print(n)
去掉第三轴,变成二维数组,维度为(10,1): [[0] [1] [2] [3] [4] [5] [6] [7] [8] [9]]
再举一个例子:
p = np.squeeze(g,2)
print(p)
去掉第2轴,得到二维数组,维度为(1,10):
[[0 1 2 3 4 5 6 7 8 9]]
p = np.squeeze(g,0)
print(p)
去掉第0轴,得到二维数组,维度为(10,1):
[[0] [1] [2] [3] [4] [5] [6] [7] [8] [9]]
在matplotlib画图中,非单维的数组在plot时候会出现问题,(1,nx)不行,但是(nx, )可以,(nx,1)也可以。
如下:
import matplotlib.pyplot as plt
squares =np.array([[1,4,9,16,25]])
print(squares.shape)
square的维度为(1,5),无法画图:
做如下修改:
plt.plot(np.squeeze(squares))
plt.show()
square的维度为(5,),可以画图:
或者做如下修改
squares1 = squares.reshape(5,1)
plt.plot(squares1)
plt.show()
square的维度为(5,1),可以画图: