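The `l2_loss` function implements a masked squared-error (L2) loss, a close relative of the mean squared error (MSE) commonly used in regression tasks. It computes the squared difference between the predicted values (`input`) and the ground truth (`target`), applies a mask so that only selected regions contribute, and normalizes the result by the batch size. Because it divides by the batch size rather than by the total number of elements, the result is a per-sample sum of squared errors rather than a true element-wise mean.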
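```python
def l2_loss(input, target, mask, batch_size):
    # Masked residual: entries where mask == 0 contribute nothing
    loss = (input - target) * mask
    # Half squared error, normalized by the number of samples in the batch
    loss = (loss * loss) / 2 / batch_size
    # Reduce over all dimensions to a single scalar
    return loss.sum()
```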
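A minimal usage sketch, assuming NumPy arrays in place of framework tensors (the function only relies on element-wise arithmetic and `.sum()`, which PyTorch tensors also provide). The shapes and values are made up for illustration: a batch of two single-channel 2x2 "heatmaps" with the top-left pixel masked out in every sample.

```python
import numpy as np

def l2_loss(input, target, mask, batch_size):
    loss = (input - target) * mask
    loss = (loss * loss) / 2 / batch_size
    return loss.sum()

pred = np.ones((2, 1, 2, 2))   # [batch_size, channels, height, width]
gt = np.zeros((2, 1, 2, 2))
mask = np.ones((2, 1, 2, 2))
mask[:, :, 0, 0] = 0.0         # ignore the top-left pixel of every sample

loss = l2_loss(pred, gt, mask, batch_size=2)
print(loss)  # 1.5: six unmasked pixels, each contributing 1**2 / 2 / 2 = 0.25
```

Changing the prediction at a masked position leaves the loss unchanged, since its residual is multiplied by 0 before squaring.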
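Parameters:

- `input`: predicted tensor of shape `[batch_size, channels, height, width]`.
- `target`: ground-truth tensor with the same shape as `input`.
- `mask`: tensor with the same shape as `input`, typically 1 where the loss should be counted and 0 where it should be ignored.
- `batch_size`: number of samples in the batch, used to normalize the loss.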
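In pose estimation, a common pattern (assumed here; the original code does not prescribe it) is to use one heatmap channel per keypoint and to mask out keypoints that are not annotated. Because masking is a plain element-wise product, a mask of shape `[batch_size, channels, 1, 1]` broadcasts across height and width under NumPy and PyTorch broadcasting rules, giving the same result as a fully expanded mask. The sketch below uses NumPy arrays in place of framework tensors, and the `visible` flags are hypothetical data.

```python
import numpy as np

def l2_loss(input, target, mask, batch_size):
    loss = (input - target) * mask
    loss = (loss * loss) / 2 / batch_size
    return loss.sum()

rng = np.random.default_rng(0)
B, C, H, W = 2, 3, 4, 4                 # batch, keypoint channels, heatmap size
pred = rng.random((B, C, H, W))         # predicted heatmaps
gt = rng.random((B, C, H, W))           # ground-truth heatmaps

# Hypothetical per-keypoint visibility flags: 1 = annotated, 0 = missing.
visible = np.array([[1, 1, 1],
                    [1, 1, 0]], dtype=pred.dtype)
mask = visible[:, :, None, None]        # shape [B, C, 1, 1], broadcasts over H, W

loss = l2_loss(pred, gt, mask, batch_size=pred.shape[0])

# Same loss with the mask explicitly expanded to the full input shape.
full_mask = np.broadcast_to(mask, pred.shape)
loss_full = l2_loss(pred, gt, full_mask, batch_size=pred.shape[0])
print(loss, loss_full)
```

Passing `pred.shape[0]` as `batch_size` keeps the normalization tied to the actual batch rather than a hard-coded constant.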
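Compute the Difference:

```python
loss = (input - target) * mask
```

Takes the element-wise difference between the predicted values (`input`) and the ground truth (`target`), then multiplies it by `mask` so that only the relevant regions contribute; wherever the mask is 0, the residual is zeroed out.

Square the Difference: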
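```python
loss = (loss * loss) / 2 / batch_size
```

Squares the masked residual element-wise and divides by 2 to match the usual definition of the L2 loss:

$$ L_2(x, y) = \frac{1}{2}(x - y)^2 $$

The factor of 1/2 cancels the 2 produced by differentiating the square, so the gradient with respect to x is simply (x - y). The result is then divided by `batch_size` so that the loss does not grow with the number of samples in the batch, which keeps its scale comparable across batch sizes.

Sum the Loss:

```python
return loss.sum()
```

Adds up the per-element terms over all dimensions (batch, channels, height, width) and returns a single scalar.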
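Putting the three steps together, with x = `input`, y = `target`, m = `mask` and B = `batch_size`, the function computes

$$ \mathcal{L} = \frac{1}{2B} \sum_{b,c,h,w} \big( m_{bchw} \, (x_{bchw} - y_{bchw}) \big)^2, \qquad \frac{\partial \mathcal{L}}{\partial x_{bchw}} = \frac{m_{bchw}^2 \, (x_{bchw} - y_{bchw})}{B} $$

The NumPy sketch below checks both claims made in the steps above: that the analytic gradient matches central finite differences, and that repeating the batch (while passing the doubled `batch_size`) leaves the loss unchanged. NumPy stands in for framework tensors, and the helpers `analytic_grad` and `numeric_grad` are hypothetical names introduced only for this check.

```python
import numpy as np

def l2_loss(input, target, mask, batch_size):
    loss = (input - target) * mask
    loss = (loss * loss) / 2 / batch_size
    return loss.sum()

def analytic_grad(input, target, mask, batch_size):
    # d/dx of (m * (x - y))**2 / 2 / B is m**2 * (x - y) / B;
    # the 1/2 cancels the 2 from differentiating the square.
    # For a binary mask, m**2 == m.
    return mask * mask * (input - target) / batch_size

def numeric_grad(input, target, mask, batch_size, eps=1e-6):
    # Central finite differences, perturbing one element at a time.
    grad = np.zeros_like(input)
    for idx in np.ndindex(input.shape):
        plus, minus = input.copy(), input.copy()
        plus[idx] += eps
        minus[idx] -= eps
        grad[idx] = (l2_loss(plus, target, mask, batch_size)
                     - l2_loss(minus, target, mask, batch_size)) / (2 * eps)
    return grad

rng = np.random.default_rng(1)
x = rng.random((2, 2, 3, 3))
y = rng.random((2, 2, 3, 3))
m = (rng.random(x.shape) > 0.3).astype(x.dtype)   # random binary mask
B = x.shape[0]

grads_match = np.allclose(analytic_grad(x, y, m, B),
                          numeric_grad(x, y, m, B), atol=1e-6)

# Batch-size normalization: duplicating the batch leaves the loss unchanged.
x2, y2, m2 = (np.concatenate([a, a]) for a in (x, y, m))
batch_invariant = np.isclose(l2_loss(x, y, m, B), l2_loss(x2, y2, m2, 2 * B))

print(grads_match, batch_invariant)
```

Because the loss is quadratic in each element, central differences agree with the analytic gradient up to floating-point rounding.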
The `l2_loss` function is designed for tasks like human pose estimation, where: