Project Code

class PoseEstimationWithMobileNet(nn.Module):
    """Lightweight multi-stage pose-estimation network.

    A MobileNet-style backbone (``self.model``) produces 512-channel features,
    which ``self.cpm`` maps to ``num_channels`` channels.  An initial stage
    predicts heatmaps and part affinity fields (PAFs) from those features, and
    every refinement stage re-predicts them from the same features
    concatenated with the previous stage's two outputs.

    Args:
        num_refinement_stages: how many ``RefinementStage`` modules follow the
            initial stage (0 means initial stage only).
        num_channels: channel width of the ``Cpm`` output shared by all stages.
        num_heatmaps: channel count of each stage's heatmap output.
        num_pafs: channel count of each stage's PAF output.
    """

    def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
        super().__init__()
        # Attribute names and Sequential indices determine the state_dict keys,
        # and construction order determines the weight-init RNG sequence, so
        # both are kept exactly as before.
        backbone_layers = [
            conv(3, 32, stride=2, bias=False),
            conv_dw(32, 64),
            conv_dw(64, 128, stride=2),
            conv_dw(128, 128),
            conv_dw(128, 256, stride=2),
            conv_dw(256, 256),
            conv_dw(256, 512),  # conv4_2
            conv_dw(512, 512, dilation=2, padding=2),
            conv_dw(512, 512),
            conv_dw(512, 512),
            conv_dw(512, 512),
            conv_dw(512, 512),  # conv5_5
        ]
        self.model = nn.Sequential(*backbone_layers)
        self.cpm = Cpm(512, num_channels)

        self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)
        # Refinement input = Cpm features + previous heatmaps + previous PAFs,
        # concatenated along the channel axis in forward().
        refinement_in_channels = num_channels + num_heatmaps + num_pafs
        self.refinement_stages = nn.ModuleList(
            RefinementStage(refinement_in_channels, num_channels, num_heatmaps, num_pafs)
            for _ in range(num_refinement_stages)
        )

    def forward(self, x):
        """Run the backbone, the Cpm block and every prediction stage on ``x``.

        Returns:
            The list returned by ``self.initial_stage``, extended in place with
            each refinement stage's outputs.  Assuming every stage yields
            ``[heatmaps, pafs]`` (TODO confirm), the last two entries are the
            most refined predictions.
        """
        features = self.cpm(self.model(x))

        outputs = self.initial_stage(features)
        for stage in self.refinement_stages:
            # Condition on the shared features plus the two most recent outputs.
            stage_input = torch.cat([features, outputs[-2], outputs[-1]], dim=1)
            outputs.extend(stage(stage_input))

        return outputs

The PoseEstimationWithMobileNet class is the main model for lightweight human pose estimation. It combines a MobileNet-based backbone for feature extraction with an initial prediction stage and a configurable number of refinement stages. Each stage outputs heatmaps (for keypoint locations) and part affinity fields (PAFs, for connections between keypoints). The forward pass returns every stage's outputs in a single flat list, so the final, most refined predictions are its last two entries.


Class Definition

class PoseEstimationWithMobileNet(nn.Module):
    # Excerpt: only the constructor signature is shown here; the complete
    # definition appears under "Project Code" and is broken down below.
    def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
        # num_heatmaps / num_pafs are the channel counts of each stage's two
        # outputs; num_channels is the width of the shared Cpm features.
        super().__init__()

Components

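  1. Backbone (self.model): a MobileNet-style stack of convolutions that turns the 3-channel input into 512-channel feature maps.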

    # Backbone: one plain conv followed by conv_dw blocks. Three layers use
    # stride=2 (presumably giving 1/8 input resolution if the other layers
    # preserve spatial size -- TODO confirm against conv/conv_dw). One layer
    # uses dilation=2 with padding=2, presumably to widen the receptive field
    # without further downsampling.
    self.model = nn.Sequential(
        conv(     3,  32, stride=2, bias=False),
        conv_dw( 32,  64),
        conv_dw( 64, 128, stride=2),
        conv_dw(128, 128),
        conv_dw(128, 256, stride=2),
        conv_dw(256, 256),
        conv_dw(256, 512),  # conv4_2
        conv_dw(512, 512, dilation=2, padding=2),
        conv_dw(512, 512),
        conv_dw(512, 512),
        conv_dw(512, 512),
        conv_dw(512, 512)   # conv5_5
    )
    
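  2. Convolutional Pose Machine (CPM) block (self.cpm): maps the 512-channel backbone output to num_channels feature channels, the width consumed by every prediction stage.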

    # Maps the 512-channel backbone output to num_channels channels; these
    # features are reused as input by every prediction stage.
    self.cpm = Cpm(512, num_channels)
    
  3. Initial Stage (self.initial_stage): produces the first heatmap and PAF predictions from the CPM features.

    # First prediction stage: num_channels features in, num_heatmaps heatmap
    # channels and num_pafs PAF channels out.
    self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)
    
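  4. Refinement Stages (self.refinement_stages): num_refinement_stages modules, each taking the CPM features concatenated with the previous stage's heatmaps and PAFs.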

    # One RefinementStage per requested refinement step. Its input is the Cpm
    # features concatenated with the previous stage's heatmaps and PAFs, hence
    # the input width num_channels + num_heatmaps + num_pafs.
    self.refinement_stages = nn.ModuleList()
    for idx in range(num_refinement_stages):
        self.refinement_stages.append(RefinementStage(num_channels + num_heatmaps + num_pafs, num_channels,
                                                      num_heatmaps, num_pafs))
    

Forward Method

def forward(self, x):
    """Run the backbone, the Cpm block and all prediction stages on ``x``.

    Returns the list produced by ``self.initial_stage``, extended in place
    with each refinement stage's outputs. Assuming every stage returns
    ``[heatmaps, pafs]`` (TODO confirm), the last two entries are the most
    refined predictions.
    """
    backbone_features = self.model(x)
    backbone_features = self.cpm(backbone_features)

    stages_output = self.initial_stage(backbone_features)
    for refinement_stage in self.refinement_stages:
        # Feed the shared features plus the previous stage's two outputs,
        # concatenated along dim=1 (presumably the channel axis).
        stages_output.extend(
            refinement_stage(torch.cat([backbone_features, stages_output[-2], stages_output[-1]], dim=1))
        )

    return stages_output