🚀 Understanding find_closest_centroids in K-Means Clustering


In K-Means clustering, one of the first steps is to assign each data point to the closest centroid. This step is handled by a function often called find_closest_centroids. Let’s walk through how to implement it step by step with NumPy.

🔎 Problem Setup

We are given:

  • X: an array of data points (shape (m, n)).
  • centroids: an array of centroid locations (shape (K, n)).
The task is to compute, for each example X[i], the index of the centroid closest to it.
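
Concretely, "closest" means smallest Euclidean distance. Here is a quick sketch for a single point (the point and centroid values below are made up for illustration):

```python
import numpy as np

# One 2-D point and three candidate centroids (illustrative values)
x_i = np.array([2.0, 3.0])
centroids = np.array([[3.0, 3.0],
                      [6.0, 2.0],
                      [8.0, 5.0]])

# Euclidean distance from x_i to each centroid via broadcasting
dists = np.linalg.norm(x_i - centroids, axis=1)
print(dists)             # distance to each of the 3 centroids
print(np.argmin(dists))  # index of the closest centroid -> 0 here
```

The full function simply repeats this argmin-over-distances computation for every row of X.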

🛠️ Step-by-Step Implementation

Here’s the full implementation using a simple for-loop and np.linalg.norm to compute distances:


import numpy as np

# UNQ_C1
# GRADED FUNCTION: find_closest_centroids

def find_closest_centroids(X, centroids):
    """
    Computes the centroid memberships for every example
    
    Args:
        X (ndarray): (m, n) Input values      
        centroids (ndarray): (K, n) centroids
    
    Returns:
        idx (array_like): (m,) closest centroids
    """

    # Set K
    K = centroids.shape[0]

    # You need to return the following variables correctly
    idx = np.zeros(X.shape[0], dtype=int)

    ### START CODE HERE ###
    for i in range(X.shape[0]):
        # Array to hold distance between X[i] and each centroids[j]
        distance = [] 
        for j in range(K):
            norm_ij = np.linalg.norm(X[i] - centroids[j])   # distance between point and centroid
            distance.append(norm_ij)

        idx[i] = np.argmin(distance)   # index of closest centroid
    ### END CODE HERE ###
    
    return idx
  

🧪 Testing the Function

Let’s test our implementation on a small dataset:


# Load an example dataset (load_data is a helper provided with the course materials)
X = load_data()

print("First five elements of X are:\n", X[:5]) 
print('The shape of X is:', X.shape)

# Select an initial set of centroids (3 Centroids)
initial_centroids = np.array([[3,3], [6,2], [8,5]])

# Find closest centroids
idx = find_closest_centroids(X, initial_centroids)

# Print closest centroids for the first three elements
print("First three elements in idx are:", idx[:3])
  

✅ Expected Output

First three elements in idx are: [0 2 1]
  

💡 Key Insights

  • We loop over every point X[i] and compute its distance to every centroid centroids[j].
  • We store the distances in a list and then use np.argmin to find the closest centroid.
  • The result idx is an array of integers representing the centroid index for each data point.

⚡ Bonus: Vectorized Version

Although the loop-based approach is clear, NumPy allows us to write a faster vectorized version that avoids explicit loops:


def find_closest_centroids_vectorized(X, centroids):
    # Broadcasting: (m, 1, n) - (K, n) -> (m, K, n); the norm over axis 2
    # gives an (m, K) matrix of point-to-centroid distances
    distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
    # For each row (point), pick the column (centroid) with the smallest distance
    idx = np.argmin(distances, axis=1)
    return idx
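
A quick way to build confidence in the vectorized version is to check that it agrees with the loop-based approach on random data. The check below is a sketch, not part of the original assignment:

```python
import numpy as np

rng = np.random.default_rng(0)
X_rand = rng.random((50, 2))   # 50 random 2-D points
cents = rng.random((4, 2))     # 4 random centroids

# Loop-based reference, mirroring the graded function
def find_closest_loop(X, centroids):
    idx = np.zeros(X.shape[0], dtype=int)
    for i in range(X.shape[0]):
        idx[i] = np.argmin(np.linalg.norm(X[i] - centroids, axis=1))
    return idx

# Vectorized computation, as in find_closest_centroids_vectorized
distances = np.linalg.norm(X_rand[:, np.newaxis] - cents, axis=2)
idx_vec = np.argmin(distances, axis=1)

assert np.array_equal(idx_vec, find_closest_loop(X_rand, cents))
print("loop and vectorized versions agree")
```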
  

🎯 Conclusion

Implementing find_closest_centroids is a key building block for understanding K-Means clustering. Starting with a loop-based approach builds intuition, while the vectorized version offers performance gains. Mastering both will strengthen your foundation in machine learning and NumPy programming!


