PyTorch:如何得到张量的形状的INT
问题描述:
在numpy的,V.shape
列表给出了五PyTorch:如何得到张量的形状的INT
在tensorflow V.get_shape().as_list()
给出了V的尺寸的整数列表
在pytorch中,V.size()
给出了一个大小对象,但是如何将它转换为整数?
答
简单list(var.size())
,例如:
>>> import torch
>>> from torch.autograd import Variable
>>> from torch import IntTensor
>>> var = Variable(IntTensor([[1,0],[0,1]]))
>>> var
Variable containing:
1 0
0 1
[torch.IntTensor of size 2x2]
>>> var.size()
torch.Size([2, 2])
>>> list(var.size())
[2, 2]
答
如果你NumPy
ISH语法的粉丝,那么就tensor.shape
。
In [3]: ar = torch.rand(3, 3)
In [4]: ar.shape
Out[4]: torch.Size([3, 3])
# method-1
In [7]: list(ar.shape)
Out[7]: [3, 3]
# method-2
In [8]: [*ar.shape]
Out[8]: [3, 3]
# method-3
In [9]: [*ar.size()]
Out[9]: [3, 3]
P.S.:请注意tensor.shape
是tensor.size()
的别名,尽管tensor.shape
是所讨论的张量的属性,而tensor.size()
是函数。只有在启用了GPU的机器上才能使用。
哪部分代码只适用于GPU机器? 'tensor.shape'? – rasen58