class RefinementStage(nn.Module):
    """Refinement stage: a shared trunk followed by two 1x1-conv output heads.

    The input feature map is passed through five ``RefinementStageBlock``
    modules. The resulting features feed two independent heads, one emitting
    ``num_heatmaps`` channels and one emitting ``num_pafs`` channels.

    Args:
        in_channels: Channel count of the feature map given to the first
            trunk block.
        out_channels: Channel count used by every trunk block output and by
            the hidden layer of each head.
        num_heatmaps: Channel count produced by the heatmap head.
        num_pafs: Channel count produced by the PAF head.
    """

    def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
        super().__init__()

        # Only the first block changes the channel count; the remaining four
        # keep it at ``out_channels``.
        trunk_inputs = [in_channels] + [out_channels] * 4
        self.trunk = nn.Sequential(
            *(RefinementStageBlock(width, out_channels) for width in trunk_inputs)
        )

        def build_head(num_outputs):
            # Two stacked kernel_size=1 convolutions. ``conv`` is a helper
            # defined elsewhere; ``bn=False`` / ``relu=False`` presumably
            # switch off batch norm and the trailing ReLU -- confirm against
            # its definition.
            hidden = conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False)
            projection = conv(out_channels, num_outputs, kernel_size=1, padding=0, bn=False, relu=False)
            return nn.Sequential(hidden, projection)

        # Heads are created in the same order as before (heatmaps, then PAFs).
        self.heatmaps = build_head(num_heatmaps)
        self.pafs = build_head(num_pafs)

    def forward(self, x):
        """Run the trunk once and apply both heads to its output.

        Returns:
            A two-element list ``[heatmaps, pafs]``.
        """
        shared = self.trunk(x)
        return [self.heatmaps(shared), self.pafs(shared)]
The `RefinementStage` class is a PyTorch module that refines the predictions (heatmaps and part affinity fields, or PAFs) from the previous stage of the pose estimation pipeline. It uses a series of `RefinementStageBlock`s to process input features and outputs improved heatmaps and PAFs.
class RefinementStage(nn.Module):
def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
super().__init__()
`RefinementStage` iteratively improves the predictions (heatmaps and PAFs) by refining the input feature maps.
- `in_channels`: Number of input channels in the feature map.
- `out_channels`: Number of output channels in the intermediate feature maps.
- `num_heatmaps`: Number of heatmaps to output (one for each keypoint type).
- `num_pafs`: Number of PAFs to output (two for each connection type: x and y directions).

`self.trunk`:
self.trunk = nn.Sequential(
RefinementStageBlock(in_channels, out_channels),
RefinementStageBlock(out_channels, out_channels),
RefinementStageBlock(out_channels, out_channels),
RefinementStageBlock(out_channels, out_channels),
RefinementStageBlock(out_channels, out_channels)
)
The trunk is a sequence of five `RefinementStageBlock`s.

`self.heatmaps`:
self.heatmaps = nn.Sequential(
conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
conv(out_channels, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
)
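- The first layer maps the trunk features to `out_channels` channels using a 1x1 convolution.
- The second layer projects them to `num_heatmaps` channels, without ReLU activation (`relu=False`).

`self.pafs`: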
self.pafs = nn.Sequential(
conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
conv(out_channels, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
)
This head has the same structure as `self.heatmaps`, but outputs `num_pafs` channels for the PAFs.

def forward(self, x):
trunk_features = self.trunk(x)
heatmaps = self.heatmaps(trunk_features)
pafs = self.pafs(trunk_features)
return [heatmaps, pafs]
Trunk Features:
trunk_features = self.trunk(x)
The input (`x`) is processed by `self.trunk` (a sequence of `RefinementStageBlock`s) to extract refined features.

Heatmaps:
heatmaps = self.heatmaps(trunk_features)
The trunk features are passed through `self.heatmaps` to generate the heatmaps.

PAFs:
pafs = self.pafs(trunk_features)
The trunk features are passed through `self.pafs` to generate the PAFs.

Output:
return [heatmaps, pafs]