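```python
import torch.nn as nn

# `conv` is a helper defined elsewhere in the project (not shown here);
# its `bn` argument toggles batch normalization.


class RefinementStageBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 convolution mapping in_channels to out_channels, without batch norm
        self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        # Standard convolution followed by a dilated convolution
        self.trunk = nn.Sequential(
            conv(out_channels, out_channels),
            conv(out_channels, out_channels, dilation=2, padding=2)
        )

    def forward(self, x):
        initial_features = self.initial(x)
        trunk_features = self.trunk(initial_features)
        # Residual connection: element-wise sum of the two paths
        return initial_features + trunk_features
```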
The `RefinementStageBlock` class is a PyTorch module that serves as a building block for the refinement stages of the pose estimation model. It refines input feature maps with a standard convolution followed by a dilated convolution, and a residual connection adds the channel-aligned input back to that result so that input information is preserved.
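Because the `conv` helper is not defined in this section, the block cannot be run as shown. The sketch below assumes a hypothetical `conv` that builds `Conv2d`, then `BatchNorm2d` (on by default), then `ReLU`, with a default 3×3 kernel and `padding=1`; the project's real helper may differ. Under that assumption, the block can be run on a dummy tensor to confirm its output shape. The channel and spatial sizes are arbitrary.

```python
import torch
import torch.nn as nn


def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True,
         dilation=1, stride=1, relu=True):
    """Hypothetical helper: Conv2d, then optional BatchNorm2d, then optional ReLU."""
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                        padding=padding, dilation=dilation)]
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    if relu:
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)


class RefinementStageBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 projection to out_channels, no batch norm
        self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        # 3x3 conv, then 3x3 conv with dilation=2 (3x3 under this sketch's assumed defaults)
        self.trunk = nn.Sequential(
            conv(out_channels, out_channels),
            conv(out_channels, out_channels, dilation=2, padding=2),
        )

    def forward(self, x):
        initial_features = self.initial(x)
        trunk_features = self.trunk(initial_features)
        return initial_features + trunk_features


block = RefinementStageBlock(in_channels=64, out_channels=32)  # arbitrary sizes
x = torch.randn(2, 64, 24, 24)  # (batch, channels, height, width)
y = block(x)
print(tuple(y.shape))  # (2, 32, 24, 24)
```

The channel count changes from 64 to 32 through `self.initial`, while height and width are preserved by every layer. This is what allows the final element-wise addition.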
```python
class RefinementStageBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
```
`RefinementStageBlock` refines feature maps by applying convolutions and adding a residual connection, which improves gradient flow and preserves input features.

- `in_channels`: number of channels in the input feature map.
- `out_channels`: number of channels in the output feature map.

`self.initial`:
```python
self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
```
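A 1×1 convolution without batch normalization (`bn=False`) that transforms the input channels (`in_channels`) to match the desired number of output channels (`out_channels`). With a 1×1 kernel and `padding=0`, it mixes channels at each pixel without changing the height or width (at stride 1).

`self.trunk`: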
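```python
self.trunk = nn.Sequential(
    conv(out_channels, out_channels),
    conv(out_channels, out_channels, dilation=2, padding=2)
)
```

Two convolutions applied in sequence: a standard convolution, then a dilated convolution (`dilation=2`, `padding=2`) that widens the receptive field without extra weights. Both keep `out_channels` channels, so the trunk output has the same number of channels as its input.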
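The padding values in `self.trunk` can be checked with the standard convolution output-size formula, `out = floor((n + 2p - d(k - 1) - 1) / s) + 1`. Here, dilation `d` stretches a `k`-wide kernel to an effective width of `d(k - 1) + 1`. This sketch assumes both trunk convolutions use a 3×3 kernel and stride 1, because the `conv` helper's defaults are not shown in this section. Under that assumption, both layers preserve the spatial size, and the dilated layer covers a 5×5 window using only 3×3 weights.

```python
def conv_output_size(n, kernel_size, padding=0, dilation=1, stride=1):
    """Spatial output size of a convolution along one dimension."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (n + 2 * padding - effective_kernel) // stride + 1


def stacked_receptive_field(effective_kernels):
    """Receptive field of stride-1 convolutions applied in sequence."""
    return 1 + sum(k - 1 for k in effective_kernels)


n = 46  # arbitrary input height/width
# Assumed trunk: 3x3 conv (padding=1), then 3x3 conv with dilation=2 (padding=2)
after_first = conv_output_size(n, kernel_size=3, padding=1)
after_second = conv_output_size(after_first, kernel_size=3, padding=2, dilation=2)
print(after_first, after_second)  # 46 46

dilated_kernel = 2 * (3 - 1) + 1  # effective width 5
# 1x1 initial layer, 3x3 conv, dilated conv
print(stacked_receptive_field([1, 3, dilated_kernel]))  # 7
```

The 1×1 initial layer adds nothing to the receptive field. Each output pixel of the block therefore depends on a 7×7 input neighbourhood under these assumptions, compared with 5×5 if the second convolution were not dilated.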
```python
def forward(self, x):
    initial_features = self.initial(x)
    trunk_features = self.trunk(initial_features)
    return initial_features + trunk_features
```
Initial Features:

```python
initial_features = self.initial(x)
```
The input (`x`) is passed through the `self.initial` layer to align the number of channels.

Trunk Features:
```python
trunk_features = self.trunk(initial_features)
```
The initial features are then passed through `self.trunk` (a combination of standard and dilated convolutions) to extract refined features.

Residual Connection:
```python
return initial_features + trunk_features
```
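The output of `self.trunk` is added element-wise to `initial_features` (a residual connection). The addition requires both tensors to have the same shape, which is why the input is first projected to `out_channels` by `self.initial`.

`RefinementStageBlock` is used in the `RefinementStage` class, which consists of multiple such blocks.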
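Autograd shows how the residual connection described above helps gradient flow. In this sketch, a single hypothetical 3×3 convolution with all weights and biases set to zero stands in for `self.trunk`, so the trunk contributes nothing. With the skip connection, the gradient reaching `initial_features` is still all ones. Without it, the gradient is zero everywhere. The tensor sizes and the zero initialisation are arbitrary choices for illustration, not part of the original model.

```python
import torch
import torch.nn as nn

# Stand-in for self.trunk: one 3x3 conv whose weights and biases are all zero
trunk = nn.Conv2d(8, 8, kernel_size=3, padding=1)
nn.init.zeros_(trunk.weight)
nn.init.zeros_(trunk.bias)

# With the residual connection: out = initial + trunk(initial)
initial = torch.randn(1, 8, 16, 16, requires_grad=True)
(initial + trunk(initial)).sum().backward()
grad_with_skip = initial.grad.clone()

# Without the residual connection: out = trunk(initial)
initial_plain = initial.detach().clone().requires_grad_(True)
trunk(initial_plain).sum().backward()
grad_without_skip = initial_plain.grad

print(grad_with_skip.abs().min().item())     # 1.0
print(grad_without_skip.abs().max().item())  # 0.0
```

The identity path contributes a gradient of exactly 1 per element regardless of the trunk's weights. This is why stacking several such blocks is less prone to vanishing gradients than stacking plain convolutions.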