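```python
import torch.nn as nn


# Assumed helper: the snippet calls `conv` without defining it. This sketch only
# matches the call signature used below (3x3 kernel with padding 1 by default,
# optional BatchNorm, optional ReLU); the project's own helper may differ.
def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, relu=True):
    modules = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)]
    if bn:
        modules.append(nn.BatchNorm2d(out_channels))
    if relu:
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)


class InitialStage(nn.Module):
    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
        # Shared trunk: three convolutions that keep num_channels channels.
        self.trunk = nn.Sequential(
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False),
            conv(num_channels, num_channels, bn=False)
        )
        # Heatmap head: two 1x1 convs, with no ReLU on the final output.
        self.heatmaps = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
        )
        # PAF head: same structure, with num_pafs output channels.
        self.pafs = nn.Sequential(
            conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
        )

    def forward(self, x):
        trunk_features = self.trunk(x)
        heatmaps = self.heatmaps(trunk_features)
        pafs = self.pafs(trunk_features)
        return [heatmaps, pafs]
```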
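The standard convolution output-size formula shows why the heatmaps and PAFs keep the spatial size of the input feature map. The sketch below traces the shapes by hand. It assumes the `conv` helper defaults to a 3×3 kernel with padding 1 and stride 1 in the trunk, since the snippet does not show these defaults. The 128-channel, 46×46 input and the 19/38 output counts are illustrative values only.

```python
def conv_out_size(size: int, kernel: int, padding: int, stride: int = 1, dilation: int = 1) -> int:
    """Standard 2-D convolution output size along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def trace_initial_stage(shape, num_heatmaps, num_pafs):
    """Trace an (N, C, H, W) input through the trunk and both heads.

    Assumes the trunk convs use kernel 3, padding 1, stride 1 (assumed helper
    defaults); the heads use the explicit kernel_size=1, padding=0.
    """
    n, _, h, w = shape
    # Trunk: three 3x3 convs with padding 1 leave H and W unchanged.
    for _ in range(3):
        h, w = conv_out_size(h, 3, 1), conv_out_size(w, 3, 1)
    # Both heads see the same trunk output, so one trace covers both:
    # two 1x1 convs with padding 0 also leave H and W unchanged.
    for _ in range(2):
        h, w = conv_out_size(h, 1, 0), conv_out_size(w, 1, 0)
    return (n, num_heatmaps, h, w), (n, num_pafs, h, w)


heatmap_shape, paf_shape = trace_initial_stage((1, 128, 46, 46), num_heatmaps=19, num_pafs=38)
print(heatmap_shape, paf_shape)  # (1, 19, 46, 46) (1, 38, 46, 46)
```

Only the channel dimension changes across the stage. Each output pixel therefore lines up with the same location in the input feature map.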
The `InitialStage` class is a PyTorch module that represents the first stage of the pose estimation pipeline. It processes feature maps from the backbone network (or the CPM module) and generates two outputs: heatmaps and part affinity fields (PAFs). These outputs are essential for detecting keypoints and their connections in human pose estimation.
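The number of output channels follows directly from the skeleton the model is trained on. Many OpenPose-style models trained on COCO use 18 keypoint types and 19 limb connections, and they often add one extra background heatmap. These numbers are an assumption here, because the text does not say which skeleton this model uses.

```python
# Assumed COCO-style skeleton: 18 keypoint types and 19 limb connections.
NUM_KEYPOINTS = 18
NUM_CONNECTIONS = 19

# One heatmap per keypoint type, plus an (assumed) background channel.
num_heatmaps = NUM_KEYPOINTS + 1
# Two PAF channels (x and y components) per connection.
num_pafs = 2 * NUM_CONNECTIONS

print(num_heatmaps, num_pafs)  # 19 38
```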
```python
class InitialStage(nn.Module):
    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
```
`InitialStage` processes input feature maps and produces heatmaps and PAFs. Its constructor takes three arguments:

- `num_channels`: number of input channels in the feature map.
- `num_heatmaps`: number of heatmaps to output (one for each keypoint type).
- `num_pafs`: number of PAFs to output (two for each connection type, one each for the x and y directions).

`self.trunk`
```python
self.trunk = nn.Sequential(
    conv(num_channels, num_channels, bn=False),
    conv(num_channels, num_channels, bn=False),
    conv(num_channels, num_channels, bn=False)
)
```
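Stacking the three trunk convolutions widens the area of the input feature map that each output pixel can see. For stride-1 convolutions, each layer adds `dilation * (kernel - 1)` pixels to that area. The sketch below assumes the `conv` helper's default kernel is 3×3 with stride 1, which the snippet does not confirm. The receptive field it computes is measured on the stage's input feature map, not on the original image.

```python
def stacked_receptive_field(num_layers: int, kernel: int, dilation: int = 1) -> int:
    """Receptive field of num_layers stride-1 convolutions applied in sequence."""
    rf = 1
    for _ in range(num_layers):
        # Each stride-1 layer extends the field by dilation * (kernel - 1) pixels.
        rf += dilation * (kernel - 1)
    return rf


# Trunk: three convs, assuming a 3x3 default kernel and stride 1.
print(stacked_receptive_field(3, kernel=3))  # 7
# The two 1x1 convs in each head do not widen it at all.
print(stacked_receptive_field(2, kernel=1))  # 1
```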
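The trunk consists of three convolutional layers built with the `conv` helper, each with:

- `num_channels` input and output channels, so the channel count is unchanged;
- no batch normalization (`bn=False`);
- the helper's default kernel size, padding, and activation, since none of these arguments is overridden.

`self.heatmaps`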
```python
self.heatmaps = nn.Sequential(
    conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
    conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
)
```
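Both layers in the heatmap branch are 1×1 convolutions, so their parameter count can be worked out by hand: `in_channels * out_channels * k * k`, plus one bias per output channel. The sketch below assumes the convolutions include a bias term, which the snippet does not show. The `num_channels = 128` value is illustrative.

```python
def conv_params(in_ch: int, out_ch: int, kernel: int, bias: bool = True) -> int:
    """Learnable parameters in one 2-D convolution: weights plus optional bias."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)


def head_params(num_channels: int, num_outputs: int, hidden: int = 512) -> int:
    """Two stacked 1x1 convs: num_channels -> hidden -> num_outputs."""
    return conv_params(num_channels, hidden, 1) + conv_params(hidden, num_outputs, 1)


print(head_params(128, 19))  # 75795 (heatmap head with the assumed sizes)
print(head_params(128, 38))  # 85542 (PAF head with the assumed sizes)
```

Most of each head's parameters sit in the first layer, the expansion to 512 channels. Swapping `num_heatmaps` for `num_pafs` changes only the much smaller second layer.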
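The heatmap branch applies two 1×1 convolutions. The first maps the trunk features to 512 channels. The second produces `num_heatmaps` channels without a ReLU activation (`relu=False`), so the final outputs are not clamped at zero.

`self.pafs`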
```python
self.pafs = nn.Sequential(
    conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
    conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
)
```
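Leaving out the final ReLU matters especially for the PAF branch. Each pair of PAF channels holds the x and y components of a direction, and either component can be negative. The toy example below uses a hypothetical PAF vector to show how a final ReLU would destroy part of the direction.

```python
def relu(values):
    """Elementwise ReLU: negative entries become zero."""
    return [max(0.0, v) for v in values]


# Hypothetical PAF (x, y) for a limb pointing left and down in image coordinates.
paf_xy = [-0.6, 0.8]

print(relu(paf_xy))  # [0.0, 0.8]: the leftward x component is lost
```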
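The PAF branch has the same structure as `self.heatmaps`, but outputs `num_pafs` channels for the PAFs.

The `forward` method runs the trunk once and feeds its output to both branches:

```python
def forward(self, x):
    trunk_features = self.trunk(x)
    heatmaps = self.heatmaps(trunk_features)
    pafs = self.pafs(trunk_features)
    return [heatmaps, pafs]
```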
Trunk Features:
```python
trunk_features = self.trunk(x)
```
The input (`x`) is processed by `self.trunk` to extract refined features.

Heatmaps:
```python
heatmaps = self.heatmaps(trunk_features)
```
The trunk features are passed through `self.heatmaps` to generate the heatmaps.

PAFs:
```python
pafs = self.pafs(trunk_features)
```
The same trunk features are passed through `self.pafs` to generate the PAFs.

Output:
```python
return [heatmaps, pafs]
```

The method returns a list containing the heatmaps and the PAFs, in that order.
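Because each connection type has two PAF channels, a downstream step must pair them up before it can read a direction. The sketch below does this for the PAF values at a single pixel, using plain Python lists in place of tensors. Both the interleaved layout (channel `2k` holds x and channel `2k + 1` holds y for connection `k`) and the helper name `paf_vectors_at_pixel` are hypothetical and not taken from the source.

```python
def paf_vectors_at_pixel(paf_values):
    """Group one pixel's PAF channels into an (x, y) vector per connection.

    Assumes a hypothetical interleaved layout: channel 2k is the x component
    and channel 2k + 1 the y component of connection k.
    """
    if len(paf_values) % 2 != 0:
        raise ValueError("expected an even number of PAF channels")
    return [(paf_values[i], paf_values[i + 1]) for i in range(0, len(paf_values), 2)]


# Hypothetical values at one pixel for num_pafs = 4 (two connections).
vectors = paf_vectors_at_pixel([0.0, 1.0, 0.6, -0.8])
print(vectors)  # [(0.0, 1.0), (0.6, -0.8)]
```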