Project Code

class RefinementStage(nn.Module):
    """Refinement stage that produces heatmaps and PAFs from shared features.

    The input is passed through a trunk of five ``RefinementStageBlock``
    modules. The trunk output is then fed to two independent heads, one for
    heatmaps and one for part affinity fields (PAFs).
    """
    def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
        """Build the trunk and the two prediction heads.

        Args:
            in_channels: First argument of the first trunk block
                (presumably the channel count of the incoming features).
            out_channels: Width used by every later trunk block and by the
                first conv of each head.
            num_heatmaps: Second argument of the heatmap head's final conv
                (presumably its number of output channels).
            num_pafs: Second argument of the PAF head's final conv
                (presumably its number of output channels).
        """
        super().__init__()
        # Shared trunk: only the first block changes the width
        # (in_channels -> out_channels); the other four keep out_channels.
        self.trunk = nn.Sequential(
            RefinementStageBlock(in_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels),
            RefinementStageBlock(out_channels, out_channels)
        )
        # Heatmap head: two kernel_size=1 convs with bn=False; the last one
        # also passes relu=False.
        # NOTE(review): `conv` is a project helper not shown here; bn/relu
        # presumably toggle batch norm and the activation -- confirm there.
        self.heatmaps = nn.Sequential(
            conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
            conv(out_channels, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
        )
        # PAF head: same layout as the heatmap head, ending in num_pafs.
        self.pafs = nn.Sequential(
            conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
            conv(out_channels, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
        )

    def forward(self, x):
        """Run the trunk once and apply both heads to its output.

        Args:
            x: Input passed directly to the trunk.

        Returns:
            A two-element list ``[heatmaps, pafs]``.
        """
        # The trunk is evaluated a single time; both heads reuse its output.
        trunk_features = self.trunk(x)
        heatmaps = self.heatmaps(trunk_features)
        pafs = self.pafs(trunk_features)
        return [heatmaps, pafs]

The RefinementStage class is a PyTorch module that refines the predictions (heatmaps and part affinity fields, or PAFs) from the previous stage of the pose estimation pipeline. It passes the input features through a shared trunk of five RefinementStageBlocks, then feeds the trunk output to two parallel heads built from 1x1 convolutions, and returns the improved predictions as the list [heatmaps, pafs].


Class Definition

class RefinementStage(nn.Module):
    # The constructor receives the input/output widths of the stage and the
    # sizes of the two outputs (heatmaps and PAFs); the nn.Module base class
    # is initialised before any submodule is assigned.
    def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
        super().__init__()

Components

  1. self.trunk

    # Five RefinementStageBlocks applied in sequence: the first takes
    # in_channels -> out_channels, the remaining four keep out_channels.
    self.trunk = nn.Sequential(
        RefinementStageBlock(in_channels, out_channels),
        RefinementStageBlock(out_channels, out_channels),
        RefinementStageBlock(out_channels, out_channels),
        RefinementStageBlock(out_channels, out_channels),
        RefinementStageBlock(out_channels, out_channels)
    )
    
  2. self.heatmaps

    # Heatmap head: two kernel_size=1, padding=0 convs with bn=False. The
    # second ends in num_heatmaps and also passes relu=False.
    # NOTE(review): `conv` is a project helper; bn/relu presumably toggle
    # batch norm and the activation -- confirm against its definition.
    self.heatmaps = nn.Sequential(
        conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
        conv(out_channels, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
    )
    
  3. self.pafs

    # PAF head: two kernel_size=1, padding=0 convs with bn=False. The second
    # ends in num_pafs and also passes relu=False.
    # NOTE(review): `conv` is a project helper; bn/relu presumably toggle
    # batch norm and the activation -- confirm against its definition.
    self.pafs = nn.Sequential(
        conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
        conv(out_channels, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
    )
    

Forward Method

def forward(self, x):
    """Run the trunk once, apply both heads, and return ``[heatmaps, pafs]``."""
    # The trunk output is computed a single time and shared by both heads.
    trunk_features = self.trunk(x)
    heatmaps = self.heatmaps(trunk_features)
    pafs = self.pafs(trunk_features)
    return [heatmaps, pafs]

Purpose in the Model