class PoseEstimationWithMobileNet(nn.Module):
    """Lightweight human pose estimation network.

    A MobileNet-style backbone extracts image features, a CPM block squeezes
    them to ``num_channels``, an initial stage predicts keypoint heatmaps and
    part affinity fields (PAFs), and each refinement stage re-predicts them
    from the features concatenated with the previous stage's outputs.
    """

    def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
        super().__init__()
        # MobileNet-style feature extractor; the deeper layers use dilation
        # instead of further striding, so spatial resolution is kept.
        self.model = nn.Sequential(
            conv(3, 32, stride=2, bias=False),
            conv_dw(32, 64),
            conv_dw(64, 128, stride=2),
            conv_dw(128, 128),
            conv_dw(128, 256, stride=2),
            conv_dw(256, 256),
            conv_dw(256, 512),  # conv4_2
            conv_dw(512, 512, dilation=2, padding=2),
            conv_dw(512, 512),
            conv_dw(512, 512),
            conv_dw(512, 512),
            conv_dw(512, 512),  # conv5_5
        )
        self.cpm = Cpm(512, num_channels)
        self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)
        # Each refinement stage consumes the shared features plus the previous
        # stage's heatmaps and PAFs, hence the widened input channel count.
        self.refinement_stages = nn.ModuleList(
            RefinementStage(num_channels + num_heatmaps + num_pafs,
                            num_channels, num_heatmaps, num_pafs)
            for _ in range(num_refinement_stages)
        )

    def forward(self, x):
        # Shared features for every stage.
        features = self.cpm(self.model(x))
        # Initial predictions: [heatmaps, pafs].
        stages_output = self.initial_stage(features)
        for stage in self.refinement_stages:
            prev_heatmaps, prev_pafs = stages_output[-2], stages_output[-1]
            stages_output.extend(
                stage(torch.cat([features, prev_heatmaps, prev_pafs], dim=1)))
        # Flat list: every stage contributes its (heatmaps, pafs) pair.
        return stages_output
The `PoseEstimationWithMobileNet` class is the main model for lightweight human pose estimation. It combines a MobileNet-based backbone for feature extraction with additional stages for keypoint detection and refinement. The model outputs heatmaps (for keypoint locations) and part affinity fields (PAFs, for connections between keypoints).
class PoseEstimationWithMobileNet(nn.Module):
def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
super().__init__()
- `num_refinement_stages`: Number of refinement stages that iteratively improve the predictions.
- `num_channels`: Number of channels in the intermediate feature maps.
- `num_heatmaps`: Number of heatmaps to output (one per keypoint type).
- `num_pafs`: Number of PAFs to output (two per connection type: x and y directions).

Backbone (`self.model`)
self.model = nn.Sequential(
conv( 3, 32, stride=2, bias=False),
conv_dw( 32, 64),
conv_dw( 64, 128, stride=2),
conv_dw(128, 128),
conv_dw(128, 256, stride=2),
conv_dw(256, 256),
conv_dw(256, 512), # conv4_2
conv_dw(512, 512, dilation=2, padding=2),
conv_dw(512, 512),
conv_dw(512, 512),
conv_dw(512, 512),
conv_dw(512, 512) # conv5_5
)
The backbone is a MobileNet-style stack built from standard (`conv`) and depthwise-separable (`conv_dw`) convolution blocks.

Convolutional Pose Machine (`self.cpm`)
self.cpm = Cpm(512, num_channels)
A `Cpm` module that processes the backbone's output and reduces it to `num_channels` channels.

Initial Stage (`self.initial_stage`)
self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)
An `InitialStage` module that generates the first set of heatmaps and PAFs.

Refinement Stages (`self.refinement_stages`)
self.refinement_stages = nn.ModuleList()
for idx in range(num_refinement_stages):
self.refinement_stages.append(RefinementStage(num_channels + num_heatmaps + num_pafs, num_channels,
num_heatmaps, num_pafs))
A list of `RefinementStage` modules, each of which refines the previous stage's predictions.

def forward(self, x):
backbone_features = self.model(x)
backbone_features = self.cpm(backbone_features)
stages_output = self.initial_stage(backbone_features)
for refinement_stage in self.refinement_stages:
stages_output.extend(
refinement_stage(torch.cat([backbone_features, stages_output[-2], stages_output[-1]], dim=1))
)
return stages_output
Backbone Features:
backbone_features = self.model(x)
backbone_features = self.cpm(backbone_features)
The input (`x`) is passed through the backbone and the CPM module to extract feature maps.

Initial Predictions:
stages_output = self.initial_stage(backbone_features)
The features are fed to the `InitialStage` to generate the first set of heatmaps and PAFs.

Refinement Stages:
for refinement_stage in self.refinement_stages:
stages_output.extend(
refinement_stage(torch.cat([backbone_features, stages_output[-2], stages_output[-1]], dim=1))
)
Each refinement stage concatenates the backbone features with the previous stage's heatmaps and PAFs, and its outputs are appended to `stages_output`.

Output:
return stages_output