matplotlib.pyplot.imshow()函数的使用

首先附上一个官方原版的说明:

matplotlib.pyplot.imshow(Xcmap=Nonenorm=Noneaspect=Noneinterpolation=Nonealpha=Nonevmin=Nonevmax=Noneorigin=None,extent=Noneshape=Nonefilternorm=1filterrad=4.0imlim=Noneresample=Noneurl=Nonehold=Nonedata=None**kwargs)

特别注意一下几个参数:

X : array_like, shape (n, m) or (n, m, 3) or (n, m, 4)

X may be an array or a PIL image. If X is an array, it can have the following shapes and types:

  • MxN – values to be mapped (float or int)
  • MxNx3 – RGB (float or uint8)
  • MxNx4 – RGBA (float or uint8)

The value for each component of MxNx3 and MxNx4 float arrays should be in the range 0.0 to 1.0. MxN arrays are mapped to colors based on the norm (mapping scalar to scalar) and the cmap (mapping the normed scalar to a color).

在deep learning 可视化中,经常需要可视化N×C×H×W的数据,而imshow的输入只能为(n, m) or (n, m, 3) or (n, m, 4),所以必须将N×C×H×W的数据转换,否则就会出错。比如,我想对mnist数据集进行一个数字的可视化,第一次是这么写的:

plt.imshow(test_data[1], cmap='gray')##torch.Size([1, 28, 28])
plt.title('%i' % test_label[1])
plt.show()

运行时报错:

raise TypeError("Invalid dimensions for image data")

TypeError: Invalid dimensions for image data

原因是灰度图必须按照二维的格式进行输入,即 (n, m),而我输入的是(C×H×W),本质上也是个三维格式,而且顺序也不对。

正确的做法如下:

plt.imshow(test_data[1].squeeze(), cmap='gray')
plt.title('%i' % test_label[1])
plt.show()

 这样的话就把输入数据变成(n, m)格式的了,运行效果如下:

matplotlib.pyplot.imshow()函数的使用

 

关于cmap参数和函数的详细用法,参考以下两篇博客即可:

http://www.cnblogs.com/denny402/p/5122594.html

https://liam.page/2014/09/11/matplotlib-tutorial-zh-cn/

 

REF:

https://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.imshow

https://blog.csdn.net/iamzhangzhuping/article/details/50803636

https://*.com/questions/36431496/typeerror-invalid-dimensions-for-image-data-when-plotting-array-with-imshow