Pytorch scatter_理解

Pytorch scatter_理解

scatter_(dim, index, src)将src中数据根据index中的索引按照dim的方向填充进调用scatter_的向量中;

1.dim=0

Pytorch scatter_理解

首先了解,dim=0,表示按行填充列数不变;scatter_参数中给定的index tensor,两个中括号中数值表示x(2,5)中第一行和第二行下标的数列数不变分别被填充到向量(3,5)的第一行和第二行;数值的下标表示取x中对应下标的值,然后填充到数值表示的行,列不变;

具体解释:比如第一个中括号第一个值0,表示取x中第0行第0列(0.3992),填充到(3,5)中的第0行第0列;第二个中括号(第二行)第一个值2,表示取x中第1行第0个值(0.5735),填充到(3,5)中的第2行第0列;

2.dim=1

Pytorch scatter_理解

dim=1,表示按列填充行数不变;同理上面例子,表示分别把1.23,填充到第一行第二列和第二行第三列;

Pytorch scatter_理解

同理:0.6737,是由于第一行有三次取值都放在第0个位置,所以最后的值覆盖掉了0.3376和0.2782,;0.9849是由于取x中第一行第一列放到(3,5)中第一行第一列;实质上是取中括号中数值的下标对应在x中的值,填充到(3,5)第x行第数值列,新值覆盖旧值;

 

;链接:https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_