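```python
from torch import nn

# `conv` is not defined in this snippet; it is assumed to be a helper from the
# surrounding project that builds Conv2d -> optional BatchNorm2d -> optional ReLU.


class InitialStage(nn.Module):
    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
        # Shared trunk: three convolutions that keep the channel count unchanged
        self.trunk = nn.Sequential(
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False)
        )
        # Heatmap head: 1x1 conv to 512 channels, then 1x1 conv to num_heatmaps (no final ReLU)
        self.heatmaps = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
        )
        # PAF head: same structure, but outputs num_pafs channels
        self.pafs = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
        )

    def forward(self, x):
        # Compute the trunk features once and share them between both heads
        trunk_features = self.trunk(x)
        heatmaps = self.heatmaps(trunk_features)
        pafs = self.pafs(trunk_features)
        return [heatmaps, pafs]
```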
The `InitialStage` class is a PyTorch module that implements the first prediction stage of the pose estimation network. It takes the feature maps produced by the backbone network (or the CPM module) and generates two outputs: heatmaps and part affinity fields (PAFs). Heatmaps indicate where each keypoint type is likely to be, and PAFs encode the connections between keypoints; together they are used to detect keypoints and link them into poses.
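To check the output shapes end to end, here is a self-contained sketch. The original code relies on a `conv` helper that is not shown; the version below is a hypothetical stand-in (Conv2d, then optional BatchNorm2d, then optional ReLU, with a 3×3 default kernel) chosen to match how `conv` is called above. The sizes (128 input channels, 19 heatmaps, 38 PAF channels, a 32×32 feature map) are illustrative assumptions.

```python
import torch
from torch import nn


def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, relu=True):
    # Hypothetical stand-in for the project's `conv` helper:
    # Conv2d, then optional BatchNorm2d, then optional ReLU.
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)]
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    if relu:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


class InitialStage(nn.Module):
    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
        # Three 3x3 convs that preserve channels and spatial size
        self.trunk = nn.Sequential(
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False),
        )
        # 1x1 convs only change the channel count
        self.heatmaps = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False),
        )
        self.pafs = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False),
        )

    def forward(self, x):
        trunk_features = self.trunk(x)
        return [self.heatmaps(trunk_features), self.pafs(trunk_features)]


# Illustrative sizes (assumed, not taken from the original code)
stage = InitialStage(num_channels=128, num_heatmaps=19, num_pafs=38)
features = torch.randn(1, 128, 32, 32)  # batch, channels, height, width
with torch.no_grad():
    heatmaps, pafs = stage(features)
print(tuple(heatmaps.shape))  # (1, 19, 32, 32)
print(tuple(pafs.shape))      # (1, 38, 32, 32)
```

Because every layer preserves the spatial size (3×3 with padding 1, or 1×1 with padding 0), both outputs keep the height and width of the input feature map; only the channel count differs.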
```python
class InitialStage(nn.Module):
    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
```
The constructor builds a shared trunk (`self.trunk`) and two output heads (`self.heatmaps` and `self.pafs`) from three parameters:

- `num_channels`: Number of input channels in the feature map.
- `num_heatmaps`: Number of heatmaps to output (one for each keypoint type).
- `num_pafs`: Number of PAF channels to output (two for each connection type: the x and y components).
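A small helper can make the link between a skeleton definition and these two arguments explicit. `head_sizes` below is hypothetical, not part of the original code; it counts one heatmap per keypoint and two PAF channels (x and y) per connection, with an optional extra background heatmap, which some OpenPose-style models add.

```python
def head_sizes(keypoints, connections, background_channel=False):
    """Return (num_heatmaps, num_pafs) for a skeleton (illustrative helper).

    keypoints: list of keypoint names, one heatmap each
    connections: list of (keypoint_a, keypoint_b) pairs, two PAF channels each
    background_channel: add one extra heatmap (an assumption; not all models do)
    """
    names = set(keypoints)
    for a, b in connections:
        if a not in names or b not in names:
            raise ValueError(f"unknown keypoint in connection ({a}, {b})")
    num_heatmaps = len(keypoints) + (1 if background_channel else 0)
    num_pafs = 2 * len(connections)  # x and y component per connection
    return num_heatmaps, num_pafs


# Hypothetical toy skeleton: a single arm
keypoints = ["shoulder", "elbow", "wrist"]
connections = [("shoulder", "elbow"), ("elbow", "wrist")]
print(head_sizes(keypoints, connections))                           # (3, 4)
print(head_sizes(keypoints, connections, background_channel=True))  # (4, 4)
```

For reference, OpenPose's COCO configuration uses 18 keypoints plus a background channel and 19 connections, which gives 19 heatmaps and 38 PAF channels.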
```python
self.trunk = nn.Sequential(
    conv(num_channels, num_channels, bn=False),
    conv(num_channels, num_channels, bn=False),
    conv(num_channels, num_channels, bn=False)
)
```
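`self.trunk` consists of three convolutional layers (built with `conv`), each with:

- the same number of input and output channels (`num_channels`);
- the `conv` helper's default kernel size and padding, which the 1×1 layers in the heads override;
- no batch normalization (`bn=False`);
- ReLU activation (the default in `conv`).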
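Stacking three small convolutions widens the context each output position sees. Assuming the `conv` helper's default is a 3×3 kernel with stride 1 and no dilation, the trunk's receptive field can be computed as below; `receptive_field` is an illustrative helper, not part of the original code.

```python
def receptive_field(kernel_sizes, strides=None, dilations=None):
    """Receptive field (in input positions) of a stack of conv layers."""
    n = len(kernel_sizes)
    strides = strides or [1] * n
    dilations = dilations or [1] * n
    rf, jump = 1, 1  # jump: input-space distance between adjacent outputs
    for k, s, d in zip(kernel_sizes, strides, dilations):
        effective_k = d * (k - 1) + 1  # dilation spreads the kernel taps
        rf += (effective_k - 1) * jump
        jump *= s
    return rf


print(receptive_field([3, 3, 3]))        # 7  (the trunk alone)
print(receptive_field([3, 3, 3, 1, 1]))  # 7  (1x1 head layers add nothing)
```

Measured on the input feature map, each trunk output therefore aggregates a 7×7 neighborhood, and the 1×1 layers in the heads do not enlarge it.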
```python
self.heatmaps = nn.Sequential(
    conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
    conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
)
```
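In `self.heatmaps`, a 1×1 convolution first maps `num_channels` to 512 channels (with ReLU), and a second 1×1 convolution maps 512 to `num_heatmaps` channels without ReLU activation (`relu=False`), so the output is left as a raw linear response.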
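A 1×1 convolution does not look at neighboring positions: it applies the same linear map to the channel vector at every pixel. The NumPy sketch below checks that equivalence on a single feature map with arbitrary illustrative sizes (no batch dimension, random weights, activation omitted).

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, C_out, H, W = 4, 3, 5, 6
x = rng.standard_normal((C_in, H, W))        # channels-first feature map
weight = rng.standard_normal((C_out, C_in))  # a 1x1 kernel is a C_out x C_in matrix
bias = rng.standard_normal(C_out)

# Vectorized: the same linear map applied to the channel vector at every pixel
out = np.einsum("oc,chw->ohw", weight, x) + bias[:, None, None]

# Explicit loop over pixels for comparison
out_loop = np.empty((C_out, H, W))
for i in range(H):
    for j in range(W):
        out_loop[:, i, j] = weight @ x[:, i, j] + bias

print(out.shape)                   # (3, 5, 6)
print(np.allclose(out, out_loop))  # True
```

This is why the heads only mix channels; all of the spatial context comes from the trunk.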
```python
self.pafs = nn.Sequential(
    conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
    conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
)
```
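`self.pafs` has the same structure as `self.heatmaps`, but its final layer outputs `num_pafs` channels for the PAFs.

The `forward` method runs the trunk once and passes its output to both heads:

```python
def forward(self, x):
    trunk_features = self.trunk(x)
    heatmaps = self.heatmaps(trunk_features)
    pafs = self.pafs(trunk_features)
    return [heatmaps, pafs]
```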
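The final layer of `self.pafs` uses `relu=False`, and for PAFs this matters: in the OpenPose formulation, each pair of PAF channels stores, at pixels along a limb, the x and y components of the unit vector pointing from one keypoint to the other, so the values can be negative. The sketch below (hypothetical helper names; image coordinates with y pointing down) shows a ReLU wiping out such a vector.

```python
import math


def limb_direction(p_a, p_b):
    """Unit vector (x, y) pointing from keypoint p_a to keypoint p_b."""
    dx, dy = p_b[0] - p_a[0], p_b[1] - p_a[1]
    norm = math.hypot(dx, dy)
    if norm == 0:
        raise ValueError("keypoints coincide; direction is undefined")
    return dx / norm, dy / norm


def relu(v):
    # Element-wise max(0, c), as a final ReLU would apply per channel
    return tuple(max(0.0, c) for c in v)


# Hypothetical forearm pointing up and to the left
elbow, wrist = (10.0, 10.0), (7.0, 6.0)
direction = limb_direction(elbow, wrist)
print(direction)        # (-0.6, -0.8)
print(relu(direction))  # (0.0, 0.0)
```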
Trunk Features:
```python
trunk_features = self.trunk(x)
```

The input (`x`) is processed by `self.trunk` to extract refined features.

Heatmaps:
```python
heatmaps = self.heatmaps(trunk_features)
```

The trunk features are passed through `self.heatmaps` to generate the heatmaps.

PAFs:
```python
pafs = self.pafs(trunk_features)
```

The same trunk features are passed through `self.pafs` to generate the PAFs.

Output:
```python
return [heatmaps, pafs]
```

Both outputs are returned together as a list, `[heatmaps, pafs]`.
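As a rough illustration of consuming the two outputs, the NumPy sketch below uses random stand-in arrays shaped like the outputs for a batch of one (19 heatmaps and 38 PAF channels are assumed sizes). It takes the single strongest location in each heatmap, which is a simplification since multi-person decoding typically keeps every local maximum above a threshold, and regroups the PAF channels into (x, y) pairs, assuming channels 2k and 2k+1 belong to connection k; the actual channel order depends on the model's configuration.

```python
import numpy as np

# Random stand-ins for the stage outputs (shapes only; values are meaningless)
num_heatmaps, num_pafs, H, W = 19, 38, 32, 32
rng = np.random.default_rng(0)
heatmaps = rng.random((1, num_heatmaps, H, W))
pafs = rng.standard_normal((1, num_pafs, H, W))

# Strongest (row, col) location in each heatmap channel
flat_idx = heatmaps[0].reshape(num_heatmaps, -1).argmax(axis=1)
rows, cols = np.unravel_index(flat_idx, (H, W))
peaks = list(zip(rows.tolist(), cols.tolist()))

# Regroup PAF channels as (connection, component, H, W), assuming interleaved x/y
paf_pairs = pafs[0].reshape(num_pafs // 2, 2, H, W)

print(len(peaks))       # 19
print(paf_pairs.shape)  # (19, 2, 32, 32)
```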