```python
import numpy as np


def connections_nms(a_idx, b_idx, affinity_scores):
    # From all retrieved connections that share the same starting/ending keypoints, leave only the top-scoring ones.
    # Sort candidates by descending affinity score, keeping the three arrays aligned.
    order = affinity_scores.argsort()[::-1]
    affinity_scores = affinity_scores[order]
    a_idx = a_idx[order]
    b_idx = b_idx[order]
    # Greedily accept a connection only if neither of its endpoints is already used.
    idx = []
    has_kpt_a = set()
    has_kpt_b = set()
    for t, (i, j) in enumerate(zip(a_idx, b_idx)):
        if i not in has_kpt_a and j not in has_kpt_b:
            idx.append(t)
            has_kpt_a.add(i)
            has_kpt_b.add(j)
    # An integer dtype keeps the indexing below valid even when idx is empty.
    idx = np.asarray(idx, dtype=np.int32)
    return a_idx[idx], b_idx[idx], affinity_scores[idx]
```
The `connections_nms` function performs Non-Maximum Suppression (NMS) on a set of candidate connections between keypoints. It greedily keeps the highest-scoring connections. It discards any lower-scoring connection that shares a starting or ending keypoint with a connection that has already been kept.
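The following minimal, self-contained sketch shows the function end to end. It repeats the definition above and runs it on a hypothetical toy input of four candidate connections. The arrays and their values are made up for illustration.

```python
import numpy as np


def connections_nms(a_idx, b_idx, affinity_scores):
    # Same logic as the function above: sort by descending score...
    order = affinity_scores.argsort()[::-1]
    affinity_scores = affinity_scores[order]
    a_idx = a_idx[order]
    b_idx = b_idx[order]
    # ...then keep a connection only if both endpoints are still unused.
    idx = []
    has_kpt_a = set()
    has_kpt_b = set()
    for t, (i, j) in enumerate(zip(a_idx, b_idx)):
        if i not in has_kpt_a and j not in has_kpt_b:
            idx.append(t)
            has_kpt_a.add(i)
            has_kpt_b.add(j)
    idx = np.asarray(idx, dtype=np.int32)
    return a_idx[idx], b_idx[idx], affinity_scores[idx]


# Hypothetical toy input: every pairing of start keypoints {0, 1}
# with end keypoints {0, 1}, each with a made-up affinity score.
a_idx = np.array([0, 0, 1, 1])
b_idx = np.array([0, 1, 0, 1])
scores = np.array([0.9, 0.8, 0.7, 0.2])

a_kept, b_kept, s_kept = connections_nms(a_idx, b_idx, scores)
print(a_kept, b_kept, s_kept)
```

With these inputs the function keeps (0, 0) with score 0.9 and (1, 1) with score 0.2. This greedy selection does not necessarily maximize the total score. Pairing (0, 1) with (1, 0) would give 0.8 + 0.7 = 1.5, compared with 0.9 + 0.2 = 1.1. The greedy pass gives up that optimality in exchange for simplicity: one sort followed by a single linear scan.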
Parameters:

- `a_idx`: An array of indices of the starting keypoints of the candidate connections.
- `b_idx`: An array of indices of the ending keypoints of the candidate connections.
- `affinity_scores`: An array of scores representing the strength of each candidate connection (e.g., derived from Part Affinity Fields, or PAFs).

All three are NumPy arrays of equal length. Entry `k` of each array describes the same candidate connection.

Returns: the filtered `a_idx`, `b_idx`, and `affinity_scores` arrays, containing only the top-scoring, non-overlapping connections, in descending order of score.

Sort Connections by Affinity Scores:
```python
order = affinity_scores.argsort()[::-1]
affinity_scores = affinity_scores[order]
a_idx = a_idx[order]
b_idx = b_idx[order]
```
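`argsort()` returns indices in ascending order of score, and `[::-1]` reverses them. The candidates are therefore processed from the highest `affinity_scores` value down. `a_idx` and `b_idx` are permuted with the same `order`, so the three arrays stay aligned.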
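The `argsort()[::-1]` idiom gives a descending order. However, NumPy's default sort kind (`"quicksort"`) is not guaranteed to be stable. When two candidates have exactly the same score, the documentation does not specify which one is visited first, and therefore which one is kept.

If deterministic tie-breaking matters, one option is a stable sort on the negated scores, which visits tied candidates in their original input order. This alternative is not part of the original function. The sketch below assumes floating-point scores:

```python
import numpy as np

# Hypothetical scores with a tie at 0.5 (positions 0 and 2).
scores = np.array([0.5, 0.9, 0.5, 0.7])

# The original idiom: ascending argsort, then reversed.
# The relative order of the two tied entries is not guaranteed.
order_reversed = scores.argsort()[::-1]

# Alternative: stable ascending sort of the negated scores,
# so the earlier of two tied candidates always comes first.
order_stable = np.argsort(-scores, kind="stable")

print(order_reversed)
print(order_stable)  # [1 3 0 2]
```

Both orders visit the scores from highest to lowest. They can differ only in how tied candidates are ordered.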
Initialize Tracking Structures:
```python
idx = []
has_kpt_a = set()
has_kpt_b = set()
```
- `idx`: A list that stores the positions (in the sorted arrays) of the selected connections.
- `has_kpt_a`: A set of the starting keypoints (`a_idx`) already used by a selected connection.
- `has_kpt_b`: A set of the ending keypoints (`b_idx`) already used by a selected connection.

Using sets makes each "already used?" check an average constant-time operation.

Iterate Through Connections:
```python
for t, (i, j) in enumerate(zip(a_idx, b_idx)):
    if i not in has_kpt_a and j not in has_kpt_b:
        idx.append(t)
        has_kpt_a.add(i)
        has_kpt_b.add(j)
```
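For each connection `(i, j)` at position `t`:

- Check whether the starting keypoint `i` is not already in `has_kpt_a` and the ending keypoint `j` is not already in `has_kpt_b`.
- If both are unused:
  - Append `t` to `idx`.
  - Mark `i` as used by adding it to `has_kpt_a`.
  - Mark `j` as used by adding it to `has_kpt_b`.
- Otherwise, skip the connection. Candidates are visited in descending score order, so any connection rejected here scores no higher than the connection that already claimed its endpoint.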
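The decision rule can be traced by hand. The sketch below uses a hypothetical list of six candidates, already sorted by descending score, and reports why each one is kept or dropped. It mirrors the loop above, but splits the combined condition into two branches so that the reason for each rejection is visible:

```python
# Hypothetical candidates, already sorted by descending score:
# (start keypoint i, end keypoint j, score)
candidates = [
    (0, 1, 0.95),
    (1, 1, 0.90),
    (0, 2, 0.80),
    (1, 0, 0.60),
    (2, 0, 0.50),
    (2, 2, 0.40),
]

has_kpt_a, has_kpt_b = set(), set()
kept = []
for t, (i, j, s) in enumerate(candidates):
    if i in has_kpt_a:
        print(f"t={t}: drop ({i}, {j}), start keypoint {i} already used")
    elif j in has_kpt_b:
        print(f"t={t}: drop ({i}, {j}), end keypoint {j} already used")
    else:
        # Both endpoints are free: keep the connection and claim them.
        kept.append(t)
        has_kpt_a.add(i)
        has_kpt_b.add(j)
        print(f"t={t}: keep ({i}, {j}) with score {s}")

print("kept positions:", kept)  # kept positions: [0, 3, 5]
```

Candidate 1 loses its end keypoint to candidate 0, and candidate 2 loses its start keypoint to candidate 0. Candidate 4 loses its end keypoint to candidate 3.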
Filter Connections:
```python
idx = np.asarray(idx, dtype=np.int32)
return a_idx[idx], b_idx[idx], affinity_scores[idx]
```
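- Convert the list of selected positions (`idx`) to a NumPy integer array.
- Index the sorted `a_idx`, `b_idx`, and `affinity_scores` arrays with it, retaining only the top-scoring, non-overlapping connections.

The `connections_nms` function therefore ensures that:

- Each starting keypoint (`a_idx`) is connected to at most one ending keypoint (`b_idx`).
- Each ending keypoint (`b_idx`) is connected to at most one starting keypoint (`a_idx`).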
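The `dtype=np.int32` argument in the filtering step is not cosmetic. The first (highest-scoring) candidate is always accepted, so `idx` can only end up empty when the input arrays are empty. This could happen, for instance, when no keypoints of one type were detected. In that case `np.asarray([])` without a dtype produces a float64 array, and NumPy refuses to index with a float array. A minimal sketch of both cases:

```python
import numpy as np

a_idx = np.array([], dtype=np.int64)  # no candidate connections at all
idx = []  # the selection loop never appends anything

# With an explicit integer dtype, indexing simply returns an empty array.
safe = a_idx[np.asarray(idx, dtype=np.int32)]
print(safe.shape)  # (0,)

# Without it, np.asarray([]) is float64, which is not a valid index array.
try:
    a_idx[np.asarray(idx)]
except IndexError as err:
    print("IndexError:", err)
```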