[Semantic Segmentation Series] PointRend Source Code Walkthrough

I am still a beginner, so corrections to any misunderstandings are welcome. The walkthrough below annotates the semantic segmentation pipeline DeepLabV3 + PointRend; the DeepLabV3 backbone is xception65.

Schematic: the main code flow should become much clearer if you come back to this figure after reading the walkthrough below. Note, however, that the schematic does not match the code exactly.

In the code, the fine-grained feature map is 1/4 of the input size, whereas in the figure below it has the same size as the input. Everything after that is the same.

(Figure 1: PointRend schematic)

1. Why PointRend was proposed:

A conventional semantic segmentation network, after a series of convolutions and pooling, produces a feature map at some reduced resolution, typically 1/8, 1/16, or 1/32 of the input. Each point on that map already has a class label, i.e., we know which class each (low-resolution) pixel belongs to. The map is then upsampled back to the input size to obtain the segmentation of the original image, and, as you can imagine, object boundaries tend to be inaccurate after upsampling. PointRend is meant to fix up exactly those boundaries: it ranks the points on the feature map by an uncertainty measure, picks out the N most uncertain points (those whose class assignment is ambiguous, i.e., likely boundary pixels), and re-predicts them. In other words, the method works on top of the output of some existing semantic segmentation model.
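To make the "uncertainty ranking" concrete, here is a minimal sketch (mine, not from the repository) of the measure used later in sampling_points: the gap between the most confident and second most confident class scores, negated so that the least certain points rank highest.

import torch

# toy per-pixel class scores: [B, num_classes, H, W]
logits = torch.randn(1, 19, 4, 4)

# top-2 scores per pixel; a small gap between them means an uncertain pixel
top2, _ = logits.topk(2, dim=1)
uncertainty = -(top2[:, 0] - top2[:, 1])   # values near 0 = most uncertain
print(uncertainty.shape)                   # torch.Size([1, 4, 4])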

2. PointRend training flow:

a. Rank the points on the feature map by uncertainty and pick N of them. In the training code N is x.shape[-1] // 16; the value 8096 mentioned later is the N used at inference.

The corresponding code: points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)
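As a rough shape check (my own sketch, assuming torch is imported and sampling_points / point_sample are defined as in the full source in section 4, with a toy 512x1024 input):

out = torch.randn(1, 19, 32, 64)      # coarse prediction at 1/16 resolution
x = torch.randn(1, 3, 512, 1024)      # network input

N = x.shape[-1] // 16                 # 64 points during training
points = sampling_points(out, N, k=3, beta=0.75)
print(points.shape)                   # torch.Size([1, 64, 2]), normalized (x, y) coordinates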

b. Extract the features of those N points from the first stage (c1) of xception65 and from the coarse prediction.

The example backbone is xception65, so we study it with that. The network outputs c1, c2, c3, c4, where c1 is the higher-resolution feature map (1/4 of the input) and c4 is the final feature map (1/16). The coarse prediction out is produced from c4 (and c1) by the DeepLab head, and the features of the N points above are then sampled from out and from c1.

The corresponding code:
coarse = point_sample(out, points, align_corners=False)
fine = point_sample(res2, points, align_corners=False)
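point_sample (defined in the full source below) is just a wrapper around F.grid_sample that takes coordinates in [0, 1] instead of [-1, 1] and bilinearly interpolates the features at those locations; a toy example:

feats = torch.randn(1, 256, 128, 256)          # e.g. c1 / res2 at 1/4 resolution
pts = torch.rand(1, 64, 2)                     # 64 normalized (x, y) points
sampled = point_sample(feats, pts, align_corners=False)
print(sampled.shape)                           # torch.Size([1, 256, 64])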

c. Concatenate the features of the N points, implemented with torch.cat. For example, if the coarse features are [1, 19, 8096] and the fine features from c1 are [1, 256, 8096], the result is [1, 275, 8096].

The corresponding code: feature_representation = torch.cat([coarse, fine], dim=1)
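A quick shape check of the concatenation (19 coarse channels + 256 fine channels = 275), assuming torch is imported:

coarse = torch.randn(1, 19, 64)
fine = torch.randn(1, 256, 64)
feature_representation = torch.cat([coarse, fine], dim=1)
print(feature_representation.shape)            # torch.Size([1, 275, 64])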

d. Use an MLP to make the refined per-point prediction.

The corresponding code: rend = self.mlp(feature_representation)
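Note that the "MLP" in this implementation is a single nn.Conv1d with kernel size 1, i.e., the same linear classifier applied independently to every sampled point (275 input channels to 19 class logits per point); a sketch:

mlp = torch.nn.Conv1d(275, 19, kernel_size=1)  # weights shared across all points
feature_representation = torch.randn(1, 275, 64)
rend = mlp(feature_representation)             # [1, 275, 64] -> [1, 19, 64]
print(rend.shape)                              # torch.Size([1, 19, 64])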

3. PointRend inference flow:

Basically the same as the training flow, except that it runs iteratively: the coarse prediction is upsampled by 2x at each step, the N most uncertain points are re-predicted with the MLP, and their values are scattered back into the upsampled map; this repeats until the output reaches the input resolution.
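The core of each subdivision step is the scatter_ call that writes the refined point predictions back into the flattened, upsampled map; a minimal sketch with toy shapes (not the repository code):

B, C, H, W = 1, 19, 64, 128
out = torch.randn(B, C, H, W)                  # upsampled coarse prediction
idx = torch.randint(0, H * W, (B, 32))         # flat indices of 32 uncertain points
rend = torch.randn(B, C, 32)                   # refined predictions for those points

idx = idx.unsqueeze(1).expand(-1, C, -1)       # [B, C, 32]
out = out.reshape(B, C, -1).scatter_(2, idx, rend).view(B, C, H, W)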

4. PointRend full source code:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models._utils import IntermediateLayerGetter
from .model_zoo import MODEL_REGISTRY
from .segbase import SegBaseModel
from ..config import cfg


@MODEL_REGISTRY.register(name='PointRend')
class PointRend(SegBaseModel):
    def __init__(self):
        super(PointRend, self).__init__(need_backbone=False)
        model_name = cfg.MODEL.POINTREND.BASEMODEL
        self.backbone =  MODEL_REGISTRY.get(model_name)()

        self.head = PointHead(num_classes=self.nclass)

    def forward(self, x):
        c1, _, _, c4 = self.backbone.encoder(x)  # extract the backbone feature maps: c1 (1/4 resolution) and c4 (1/16)

        out = self.backbone.head(c4, c1)
        
        result = {'res2': c1, 'coarse': out}
        result.update(self.head(x, result["res2"], result["coarse"]))
        if not self.training:
            return (result['fine'],)
        return result


class PointHead(nn.Module):
    def __init__(self, in_c=275, num_classes=19, k=3, beta=0.75):
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta

    def forward(self, x, res2, out):
        """
        1. Fine-grained features are interpolated from res2 for DeeplabV3
        2. During training we sample as many points as there are on a stride 16 feature map of the input
        3. To measure prediction uncertainty
           we use the same strategy during training and inference: the difference between the most
           confident and second most confident class probabilities.
        """
        if not self.training:
            return self.inference(x, res2, out)

        points = sampling_points(out, x.shape[-1] // 16, self.k, self.beta)  # select the coordinates of the most uncertain points

        coarse = point_sample(out, points, align_corners=False)  # sample coarse features from the prediction built on c4
        fine = point_sample(res2, points, align_corners=False)  # sample fine-grained features from c1 (res2)

        feature_representation = torch.cat([coarse, fine], dim=1)  # concatenate coarse and fine features

        rend = self.mlp(feature_representation)  # per-point MLP prediction

        return {"rend": rend, "points": points}

    @torch.no_grad()
    def inference(self, x, res2, out):
        """
        During inference, subdivision uses N=8096
        (i.e., the number of points in the stride 16 map of a 1024×2048 image)
        """
        num_points = 8096
        
        # keep subdividing: upsample the coarse prediction 2x per step until it matches the input resolution
        while out.shape[-1] != x.shape[-1]:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)

            points_idx, points = sampling_points(out, num_points, training=self.training)

            coarse = point_sample(out, points, align_corners=False)
            fine = point_sample(res2, points, align_corners=False)

            feature_representation = torch.cat([coarse, fine], dim=1)

            rend = self.mlp(feature_representation)

            B, C, H, W = out.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            # write the refined per-point predictions back into the upsampled coarse map
            out = (out.reshape(B, C, -1)
                      .scatter_(2, points_idx, rend)
                      .view(B, C, H, W))
            
        return {"fine": out}


def point_sample(input, point_coords, **kwargs):
    """
    From Detectron2, point_features.py#19
    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
    Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
    [0, 1] x [0, 1] square.
    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
        [0, 1] x [0, 1] normalized point coordinates.
    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`. The features are obtained via bilinear
            interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
    """
    add_dim = False
    if point_coords.dim() == 3:
        add_dim = True
        point_coords = point_coords.unsqueeze(2)
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)  # map [0, 1] coords to grid_sample's [-1, 1] range
    if add_dim:
        output = output.squeeze(3)
    return output


@torch.no_grad()
def sampling_points(mask, N, k=3, beta=0.75, training=True):
    """
    Follows 3.1 "Point Selection for Inference and Training".
    In training: "The sampling strategy selects N points on a feature map to train on."
    In inference: "then selects the N most uncertain points"
    Args:
        mask(Tensor): [B, C, H, W]
        N(int): `During training we sample as many points as there are on a stride 16 feature map of the input`
        k(int): Over generation multiplier
        beta(float): ratio of importance points
        training(bool): flag
    Return:
        selected_point(Tensor) : flattened indexing points [B, num_points, 2]
    """
    assert mask.dim() == 4, "Dim must be N(Batch)CHW"
    device = mask.device
    B, _, H, W = mask.shape
    mask, _ = mask.sort(1, descending=True)  # sort channels so indices 0 and 1 hold the top-2 class scores

    if not training:
        H_step, W_step = 1 / H, 1 / W
        N = min(H * W, N)
        uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])
        _, idx = uncertainty_map.view(B, -1).topk(N, dim=1)

        points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
        points[:, :, 0] = W_step / 2.0 + (idx  % W).to(torch.float) * W_step
        points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step
        return idx, points

    # Official Comment : point_features.py#92
    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to worse results. To illustrate the difference: a sampled point between two coarse predictions
    # with -1 and 1 logits has 0 logit prediction and therefore 0 uncertainty value, however, if one
    # calculates uncertainties for the coarse predictions first (-1 and -1) and samples it for the
    # center point, they will get -1 uncertainty.

    # over-generate k*N random candidate points and evaluate the (sorted) prediction at them
    over_generation = torch.rand(B, k * N, 2, device=device)
    over_generation_map = point_sample(mask, over_generation, align_corners=False)

    uncertainty_map = -1 * (over_generation_map[:, 0] - over_generation_map[:, 1])
    _, idx = uncertainty_map.topk(int(beta * N), -1)

    shift = (k * N) * torch.arange(B, dtype=torch.long, device=device)

    idx += shift[:, None]

    # beta*N most-uncertain candidates ("importance" points) plus (1 - beta)*N uniformly random "coverage" points
    importance = over_generation.view(-1, 2)[idx.view(-1), :].view(B, int(beta * N), 2)
    coverage = torch.rand(B, N - int(beta * N), 2, device=device)
    return torch.cat([importance, coverage], 1).to(device)