Project Code

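import numpy as np


def connections_nms(a_idx, b_idx, affinity_scores):
    # From all retrieved connections that share the same starting/ending keypoints, leave only the top-scoring ones.
    # Sort candidates by score (highest first) and reorder all three arrays with the same permutation.
    order = affinity_scores.argsort()[::-1]
    affinity_scores = affinity_scores[order]
    a_idx = a_idx[order]
    b_idx = b_idx[order]
    idx = []  # positions (in the sorted arrays) of the connections that are kept
    has_kpt_a = set()  # starting keypoints already used by a kept connection
    has_kpt_b = set()  # ending keypoints already used by a kept connection
    for t, (i, j) in enumerate(zip(a_idx, b_idx)):
        # Keep a connection only if neither endpoint was claimed by a higher-scoring one.
        if i not in has_kpt_a and j not in has_kpt_b:
            idx.append(t)
            has_kpt_a.add(i)
            has_kpt_b.add(j)
    # dtype=np.int32 keeps idx a valid integer index array even when it is empty.
    idx = np.asarray(idx, dtype=np.int32)
    return a_idx[idx], b_idx[idx], affinity_scores[idx]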

The connections_nms function performs Non-Maximum Suppression (NMS) on a set of candidate connections between keypoints. It visits the candidates in descending order of affinity score and keeps a connection only if neither its starting nor its ending keypoint is already used by a higher-scoring connection, so every lower-scoring connection that conflicts with a kept one is discarded.
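The same greedy rule can be restated on plain Python tuples, which makes a small worked example easy to follow. This is a sketch, not the project's code: the helper name greedy_one_to_one and the (a, b, score) tuple layout are assumptions for illustration, and tie order among equal scores may differ from the NumPy version (which reverses an ascending argsort).

```python
def greedy_one_to_one(connections):
    """Hypothetical helper: keep connections greedily by descending score so that
    every start and every end keypoint appears in at most one kept connection.

    connections: list of (a, b, score) tuples.
    """
    kept = []
    used_a, used_b = set(), set()
    # Highest score first; sorted() is stable, so ties keep their input order.
    for a, b, score in sorted(connections, key=lambda c: c[2], reverse=True):
        if a not in used_a and b not in used_b:
            kept.append((a, b, score))
            used_a.add(a)
            used_b.add(b)
    return kept


candidates = [(0, 0, 0.9), (0, 1, 0.8), (1, 0, 0.7), (1, 1, 0.6)]
result = greedy_one_to_one(candidates)
print(result)  # [(0, 0, 0.9), (1, 1, 0.6)]
```

Here (0, 1, 0.8) is dropped because start keypoint 0 is already used, and (1, 0, 0.7) is dropped because end keypoint 0 is already used.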


Function Definition

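def connections_nms(a_idx, b_idx, affinity_scores):
    # From all retrieved connections that share the same starting/ending keypoints, leave only the top-scoring ones.

The three arguments are parallel NumPy arrays: entry k of a_idx and b_idx gives the starting and ending keypoint of candidate connection k, and affinity_scores[k] is its score. The function returns the same three arrays, filtered and ordered by descending score.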
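The candidate arrays are built elsewhere in the pipeline, which is not shown here. As a purely hypothetical illustration, assuming a dense score matrix whose rows are starting keypoints and whose columns are ending keypoints, plus an arbitrary threshold, the three parallel arrays could be produced like this:

```python
import numpy as np

# Hypothetical affinity matrix: scores[i, j] scores linking start keypoint i to end keypoint j.
scores = np.array([[0.9, 0.2],
                   [0.7, 0.6]])
threshold = 0.5  # hypothetical cutoff for keeping a candidate

# np.nonzero returns the row and column indices of all entries above the threshold.
a_idx, b_idx = np.nonzero(scores > threshold)  # a_idx -> [0 1 1], b_idx -> [0 0 1]
affinity_scores = scores[a_idx, b_idx]          # -> [0.9 0.7 0.6]
```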

Step-by-Step Explanation

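  1. Sort Connections by Affinity Scores:

     argsort() returns the indices that sort the scores in ascending order, and [::-1] reverses them so the highest-scoring connection comes first. The same permutation is applied to a_idx and b_idx so the three arrays stay aligned.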

    order = affinity_scores.argsort()[::-1]
    affinity_scores = affinity_scores[order]
    a_idx = a_idx[order]
    b_idx = b_idx[order]
    
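  2. Initialize Tracking Structures:

     idx collects the positions (in the sorted arrays) of the connections to keep, while has_kpt_a and has_kpt_b record which starting and ending keypoints are already used by a kept connection. Sets give fast membership tests.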

    idx = []
    has_kpt_a = set()
    has_kpt_b = set()
    
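  3. Iterate Through Connections:

     Because connections are visited from highest to lowest score, the first connection that touches a given keypoint is the best one still available for it. A connection is kept only if neither of its keypoints is already used; its keypoints are then marked as used, so any later, lower-scoring connection that shares either keypoint is skipped.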

    for t, (i, j) in enumerate(zip(a_idx, b_idx)):
        if i not in has_kpt_a and j not in has_kpt_b:
            idx.append(t)
            has_kpt_a.add(i)
            has_kpt_b.add(j)
    
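  4. Filter Connections:

     The kept positions are converted to an integer array and used to index the sorted arrays, so the results come back in descending score order. The explicit dtype=np.int32 matters when nothing is kept: np.asarray([]) would otherwise produce a float array, which NumPy rejects as an index.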

    idx = np.asarray(idx, dtype=np.int32)
    return a_idx[idx], b_idx[idx], affinity_scores[idx]
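A short trace of the four steps on toy data (hypothetical values, not taken from the project) shows how the sorted positions and the two sets interact. It inlines the steps rather than calling connections_nms so that each keep/skip decision can be printed:

```python
import numpy as np

# Toy candidates (hypothetical): connection k links start a_idx[k] to end b_idx[k].
a_idx = np.array([0, 1, 0])
b_idx = np.array([1, 1, 0])
affinity_scores = np.array([0.3, 0.8, 0.5])

# Step 1: argsort() is ascending, so [::-1] puts the highest score first.
order = affinity_scores.argsort()[::-1]  # [1 2 0]
a_sorted, b_sorted, s_sorted = a_idx[order], b_idx[order], affinity_scores[order]

# Steps 2-3: keep a connection only if both of its endpoints are still unused.
idx, has_kpt_a, has_kpt_b = [], set(), set()
for t, (i, j) in enumerate(zip(a_sorted, b_sorted)):
    decision = "keep" if (i not in has_kpt_a and j not in has_kpt_b) else "skip"
    print(f"t={t}: ({i}, {j}) score={s_sorted[t]:.1f} -> {decision}")
    if decision == "keep":
        idx.append(t)
        has_kpt_a.add(i)
        has_kpt_b.add(j)

# Step 4: the kept positions refer to the sorted arrays, so index those.
idx = np.asarray(idx, dtype=np.int32)
print(a_sorted[idx], b_sorted[idx], s_sorted[idx])  # [1 0] [1 0] [0.8 0.5]
```

The 0.3 connection (0, 1) is skipped because both of its keypoints were already claimed by the two higher-scoring connections.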
    

Purpose

The connections_nms function ensures that:

  1. Each starting keypoint (a_idx) is connected to at most one ending keypoint (b_idx).
  2. Each ending keypoint (b_idx) is connected to at most one starting keypoint (a_idx).
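The greedy pass guarantees both conditions, but it does not necessarily maximize the total score of the kept connections. The sketch below, using hypothetical helper names and toy data, checks the one-to-one property and compares the greedy result with an exhaustive search that is only practical for a handful of candidates:

```python
from itertools import combinations


def greedy_one_to_one(connections):
    # Same greedy rule as connections_nms, applied to (a, b, score) tuples.
    kept, used_a, used_b = [], set(), set()
    for a, b, s in sorted(connections, key=lambda c: c[2], reverse=True):
        if a not in used_a and b not in used_b:
            kept.append((a, b, s))
            used_a.add(a)
            used_b.add(b)
    return kept


def is_one_to_one(connections):
    # True if no starting keypoint and no ending keypoint appears twice.
    starts = [a for a, _, _ in connections]
    ends = [b for _, b, _ in connections]
    return len(starts) == len(set(starts)) and len(ends) == len(set(ends))


def best_one_to_one(connections):
    # Exhaustive search over all subsets for the valid one with the largest total score.
    best, best_total = [], 0.0
    for r in range(1, len(connections) + 1):
        for subset in combinations(connections, r):
            total = sum(s for _, _, s in subset)
            if is_one_to_one(subset) and total > best_total:
                best, best_total = list(subset), total
    return best


candidates = [(0, 0, 0.9), (0, 1, 0.8), (1, 0, 0.8)]
greedy = greedy_one_to_one(candidates)  # [(0, 0, 0.9)]
best = best_one_to_one(candidates)      # [(0, 1, 0.8), (1, 0, 0.8)]
print(is_one_to_one(greedy), sum(s for _, _, s in greedy), sum(s for _, _, s in best))
```

In this example the greedy pass keeps only the 0.9 connection (total 0.9), while the two 0.8 connections satisfy the same constraints with a higher total. An exact maximum-weight assignment method would find that pair; the greedy pass trades that optimality for simplicity and speed.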