Pytorch张量如何得到具体值的索引

问题描述:

在python列表中,我们可以使用list.index(somevalue)。 pytorch如何做到这一点?
例如:Pytorch张量如何得到具体值的索引

a=[1,2,3] 
    print(a.index(2)) 

然后,1将输出。 pytorch tensor如何在不将其转换为python列表的情况下执行此操作?

我认为没有从list.index()到pytorch函数的直接翻译。但是,您可以使用tensor==number,然后使用nonzero()函数获得类似的结果。例如:

t = torch.Tensor([1, 2, 3]) 
print ((t == 2).nonzero()) 

这段代码返回

[大小1×1的torch.LongTensor]