Project Code

class InitialStage(nn.Module):
    """Initial prediction stage: one shared trunk feeding two output heads.

    The trunk output is computed once and passed to both the heatmap head
    and the PAF head.

    NOTE(review): ``conv`` is a project helper that is not defined in this
    block; the comments below assume its first two positional arguments are
    (in_channels, out_channels) -- confirm against its definition.
    """

    def __init__(self, num_channels, num_heatmaps, num_pafs):
        """Create the trunk and both heads.

        Args:
            num_channels: Passed to ``conv`` as both positional arguments of
                every trunk layer and as the first argument of each head.
            num_heatmaps: Second positional argument of the last heatmap
                layer (presumably its output channel count).
            num_pafs: Second positional argument of the last PAF layer
                (presumably its output channel count).
        """
        super().__init__()
        # Three identically configured layers, created in order.
        trunk_layers = [
            conv(num_channels, num_channels, bn=False) for _ in range(3)
        ]
        self.trunk = nn.Sequential(*trunk_layers)
        # Heatmap head is built before the PAF head so that layer creation
        # order matches the attribute order.
        self.heatmaps = self._make_head(num_channels, num_heatmaps)
        self.pafs = self._make_head(num_channels, num_pafs)

    @staticmethod
    def _make_head(in_channels, out_channels):
        """Return a two-layer head: in_channels -> 512 -> out_channels."""
        pointwise = dict(kernel_size=1, padding=0, bn=False)
        return nn.Sequential(
            conv(in_channels, 512, **pointwise),
            # The final layer additionally passes relu=False.
            conv(512, out_channels, relu=False, **pointwise),
        )

    def forward(self, x):
        """Run the trunk on ``x`` and return ``[heatmaps, pafs]`` as a list."""
        shared = self.trunk(x)
        return [self.heatmaps(shared), self.pafs(shared)]

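The InitialStage class is a PyTorch module that represents the first stage of the pose estimation pipeline. It takes feature maps from the backbone network (or the CPM module), passes them through a shared trunk, and feeds the trunk output to two separate heads that produce heatmaps and part affinity fields (PAFs); forward returns them as the list [heatmaps, pafs]. These outputs are essential for human pose estimation: the heatmaps are used to detect keypoints, and the PAFs are used to detect the connections between them.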


Class Definition

class InitialStage(nn.Module):
    """PyTorch module built from num_channels, num_heatmaps and num_pafs."""
    def __init__(self, num_channels, num_heatmaps, num_pafs):
        # Set up the nn.Module base class before any sub-modules are assigned.
        super().__init__()

Components

  1. self.trunk

    # Shared trunk: three conv(...) layers chained in order, each called with
    # num_channels for both positional arguments and bn=False.
    # NOTE(review): `conv` is a helper defined outside this excerpt; presumably
    # its positional arguments are (in_channels, out_channels) -- confirm, and
    # check its defaults for kernel size and padding.
    self.trunk = nn.Sequential(
        conv(num_channels, num_channels, bn=False),
        conv(num_channels, num_channels, bn=False),
        conv(num_channels, num_channels, bn=False)
    )
    
  2. self.heatmaps

    # Heatmap head: two conv(...) layers, both with kernel_size=1, padding=0
    # and bn=False; presumably num_channels -> 512 -> num_heatmaps channels.
    # The last layer also passes relu=False.
    self.heatmaps = nn.Sequential(
        conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
        conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
    )
    
  3. self.pafs

    # PAF head: same two-layer layout as the heatmap head, but the last layer
    # is called with num_pafs (presumably num_channels -> 512 -> num_pafs).
    # The last layer also passes relu=False.
    self.pafs = nn.Sequential(
        conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
        conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
    )
    

Forward Method

def forward(self, x):
    """Run the trunk on ``x`` once and feed its output to both heads.

    Returns:
        A two-element list ``[heatmaps, pafs]``; the heatmap head is
        evaluated before the PAF head.
    """
    shared = self.trunk(x)
    return [self.heatmaps(shared), self.pafs(shared)]

Purpose in the Model