Pytorch.squeeze和unsqueeze

简介和用法

squeeze和unsqueeze的作用与其翻译基本一致,被作用维度压缩和解压缩.用法相对简单,具体如下:
tensor_unsqueeze = tensor.unsqueeze(dim)
若tensor存在n个维度,则dim的取值为[-n+1,n]区间的整数,且dim的取值不能为空.
tensor_squeeze = tensor.squeeze(dim)
若tensor存在n个维度,则dim的取值为[-n,n-1]区间的证书,但dim的值可以为空

具体实例

unsqueeze

Pytorch.squeeze和unsqueeze
观察上图,不难发现规律,若用test来存储set2.unsqueeze(dim)的执行结果,那test的第dim个维度值必为1,而其他维度值按set2的维度值照抄即可.也就是说,若set2的维度数为n,unsqueeze(dim)后维度数变为n+1,而多出的维度值为1,并且位于第dim维(维度值的序号从0开始编号).

Squeeze

有参数

Pytorch.squeeze和unsqueeze
tensor.squeeze(dim)中的参数dim,表示对tensor的第dim维度进行压缩.若第dim维度不为1,则不做任何处理,若为1,则将该维度消除,即原先的维度数为n,则消除后维度数变为n-1.

无参数

Pytorch.squeeze和unsqueeze
若直接使用tensor.squeeze(),并不传递参数,此时将会对tensor的所有维度值进行遍历,若维度值为1,则消除,若维度值不为1,则保留.