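```python
import torch.nn as nn

# `conv` is a project helper (not shown here) that builds a convolution layer.


class RefinementStageBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 convolution without batch normalization: in_channels -> out_channels
        self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        # Standard convolution followed by a dilated convolution (dilation=2, padding=2)
        self.trunk = nn.Sequential(
            conv(out_channels, out_channels),
            conv(out_channels, out_channels, dilation=2, padding=2)
        )

    def forward(self, x):
        initial_features = self.initial(x)
        trunk_features = self.trunk(initial_features)
        # Residual connection: element-wise sum of the projected input and the trunk output
        return initial_features + trunk_features
```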
The RefinementStageBlock class is a PyTorch module that serves as a building block for the refinement stages in the pose estimation model. It processes input feature maps and refines them using a combination of standard and dilated convolutions, with a residual connection to preserve input information.
```python
class RefinementStageBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
```
RefinementStageBlock refines feature maps by applying convolutions and adding a residual connection, which improves gradient flow and preserves input features.

in_channels: Number of input channels in the feature map.
out_channels: Number of output channels in the feature map.

self.initial
```python
self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
```
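This is a 1×1 convolution (kernel_size=1, padding=0) without batch normalization (bn=False). It maps the input channels (in_channels) to the desired number of output channels (out_channels).

self.trunk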
```python
self.trunk = nn.Sequential(
    conv(out_channels, out_channels),                        # standard convolution
    conv(out_channels, out_channels, dilation=2, padding=2)  # dilated convolution
)
```
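The second layer of the trunk is a dilated convolution. The `conv` helper is not shown, so assume its default kernel is 3×3 with padding 1. Under that assumption, dilation 2 spreads the kernel taps two pixels apart, so each output covers a 5×5 window. Padding 2 keeps the height and width unchanged. The sketch below applies the output-size formula from the PyTorch `Conv2d` documentation and accumulates the receptive field of the two stride-1 layers. The input size 46 is arbitrary.

```python
import torch
import torch.nn as nn


def effective_kernel_size(kernel_size, dilation):
    """Span covered by a dilated kernel: k + (k - 1) * (d - 1)."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)


def conv_output_size(size, kernel_size, padding, dilation, stride=1):
    """Output length along one spatial axis (formula from the PyTorch Conv2d docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Assumed 3x3 kernels: a standard conv, then a dilated one
layers = [(3, 1, 1), (3, 2, 2)]  # (kernel_size, dilation, padding)

size, receptive_field = 46, 1
for k, d, p in layers:
    size = conv_output_size(size, k, p, d)
    receptive_field += effective_kernel_size(k, d) - 1  # valid because stride is 1

print(effective_kernel_size(3, 2))  # 5
print(size)                         # 46: spatial size is preserved
print(receptive_field)              # 7: each output sees a 7x7 input window

# Cross-check the dilated layer against PyTorch itself
dilated = nn.Conv2d(8, 8, kernel_size=3, padding=2, dilation=2)
print(dilated(torch.randn(1, 8, 46, 46)).shape)  # torch.Size([1, 8, 46, 46])
```

The 1×1 convolution in self.initial does not widen the window. Under the same 3×3 assumption, the 7×7 figure therefore also holds relative to the block's input.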
```python
def forward(self, x):
    initial_features = self.initial(x)
    trunk_features = self.trunk(initial_features)
    return initial_features + trunk_features
```
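The forward pass depends on self.initial in two ways. First, a 1×1 convolution changes only the channel dimension and applies the same linear map at every pixel. Second, the residual sum is element-wise, so both operands must have the same shape. The sketch below checks both properties. Plain `nn.Conv2d` layers stand in for the helper-built layers: any activation the helper may add is omitted, and a single 3×3 convolution simplifies the trunk. Channel counts and sizes are arbitrary.

```python
import torch
import torch.nn as nn

in_channels, out_channels = 16, 8
x = torch.randn(1, in_channels, 12, 10)

# Stand-ins: a 1x1 projection for self.initial and one 3x3 conv for self.trunk
initial = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
trunk = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

initial_features = initial(x)
print(initial_features.shape)  # torch.Size([1, 8, 12, 10])

# A 1x1 convolution is a per-pixel matrix multiply over channels, plus a bias
weight = initial.weight[:, :, 0, 0]  # shape (out_channels, in_channels)
manual = torch.einsum("oc,nchw->nohw", weight, x) + initial.bias.view(1, -1, 1, 1)
print(torch.allclose(initial_features, manual, atol=1e-5))  # True

trunk_features = trunk(initial_features)

# The raw 16-channel input cannot be added to the 8-channel trunk output
try:
    x + trunk_features
    mismatch_failed = False
except RuntimeError:
    mismatch_failed = True
print(mismatch_failed)  # True

# After projection both operands share a shape, so the residual sum is valid
out = initial_features + trunk_features
print(out.shape)  # torch.Size([1, 8, 12, 10])
```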
Initial Features:
```python
initial_features = self.initial(x)
```
The input tensor (x) is passed through the self.initial layer so that its channel count matches out_channels.

Trunk Features:
```python
trunk_features = self.trunk(initial_features)
```
The aligned features are then passed through self.trunk (a standard convolution followed by a dilated convolution) to extract refined features.

Residual Connection:
```python
return initial_features + trunk_features
```
The output of self.trunk is added element-wise to initial_features, forming the residual connection.

RefinementStageBlock is used in the RefinementStage class, which consists of multiple such blocks.
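RefinementStage itself is not shown here, so the following is a hypothetical sketch, not the project's code. It pairs the block with a hypothetical `conv` helper, assumed to be a `Conv2d` (3×3 kernel, padding 1 by default) followed by an optional `BatchNorm2d` and `ReLU`. It then chains three blocks in `nn.Sequential`. It also checks one consequence of the residual form: when the trunk outputs zeros, the block reduces to its 1×1 projection. The block count and channel sizes are illustrative.

```python
import torch
import torch.nn as nn


def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, dilation=1, relu=True):
    """Hypothetical helper: Conv2d, then optional BatchNorm2d and ReLU."""
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)]
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    if relu:
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)


class RefinementStageBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        self.trunk = nn.Sequential(
            conv(out_channels, out_channels),
            conv(out_channels, out_channels, dilation=2, padding=2),
        )

    def forward(self, x):
        initial_features = self.initial(x)
        return initial_features + self.trunk(initial_features)


# Hypothetical stage trunk: the first block changes the width, later blocks keep it
stage_trunk = nn.Sequential(
    RefinementStageBlock(24, 16),
    RefinementStageBlock(16, 16),
    RefinementStageBlock(16, 16),
)
features = stage_trunk(torch.randn(2, 24, 32, 32))
print(features.shape)  # torch.Size([2, 16, 32, 32])

# Residual check: zero the last trunk convolution so the trunk contributes nothing
block = RefinementStageBlock(24, 16)
last_conv = block.trunk[-1][0]
with torch.no_grad():
    last_conv.weight.zero_()
    last_conv.bias.zero_()
x = torch.randn(2, 24, 32, 32)
print(torch.allclose(block(x), block.initial(x)))  # True
```

Every layer preserves height and width, so blocks of this form can be stacked freely. Only the first block needs in_channels to differ from out_channels.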