In K-Means clustering, one of the first steps is to assign each data point to the
closest centroid. This step is handled by a function often called find_closest_centroids.
Let’s walk through how to implement it step by step with NumPy.
Problem Setup
We are given:
- X: an array of data points (shape (m, n)).
- centroids: an array of centroid locations (shape (K, n)).
The goal is to find, for each data point X[i], the index of the centroid closest to it.
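To make the setup concrete, here is a minimal sketch on a made-up toy dataset. The values and the names X_toy and centroids_toy are illustrative assumptions, not part of the exercise. It checks the two shapes and works out the closest centroid for a single point by hand.

```python
import numpy as np

# Hypothetical toy data: m = 4 points, n = 2 features
X_toy = np.array([[1.0, 1.0],
                  [5.0, 4.0],
                  [1.5, 2.0],
                  [6.0, 5.0]])

# Hypothetical centroids: K = 2 centroids with the same n = 2 features
centroids_toy = np.array([[1.0, 2.0],
                          [5.0, 5.0]])

print(X_toy.shape)          # (4, 2)
print(centroids_toy.shape)  # (2, 2)

# Euclidean distance from the first point to each centroid
d0 = np.linalg.norm(X_toy[0] - centroids_toy[0])  # difference (0, -1), length 1.0
d1 = np.linalg.norm(X_toy[0] - centroids_toy[1])  # difference (-4, -4), length about 5.66

# The closest centroid is the one with the smallest distance
closest_for_first_point = int(np.argmin([d0, d1]))
print(closest_for_first_point)  # 0
```

Repeating this for every row of X gives one index per data point, which is what the function has to return.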
🛠️ Step-by-Step Implementation
Here’s the full implementation using a simple for-loop and np.linalg.norm to compute distances:
import numpy as np

# UNQ_C1
# GRADED FUNCTION: find_closest_centroids
def find_closest_centroids(X, centroids):
    """
    Computes the centroid memberships for every example

    Args:
        X (ndarray): (m, n) Input values
        centroids (ndarray): (K, n) centroids

    Returns:
        idx (array_like): (m,) closest centroids
    """
    # Set K (number of centroids)
    K = centroids.shape[0]

    # You need to return the following variables correctly
    idx = np.zeros(X.shape[0], dtype=int)

    ### START CODE HERE ###
    for i in range(X.shape[0]):
        # List to hold the distance between X[i] and each centroids[j]
        distance = []
        for j in range(K):
            norm_ij = np.linalg.norm(X[i] - centroids[j])  # Euclidean distance between point and centroid
            distance.append(norm_ij)
        idx[i] = np.argmin(distance)  # index of closest centroid
    ### END CODE HERE ###

    return idx
🧪 Testing the Function
Let’s test our implementation on a small dataset:
# Load an example dataset
# (load_data() is not a NumPy function; it is assumed to be a helper provided with the exercise)
X = load_data()
print("First five elements of X are:\n", X[:5])
print('The shape of X is:', X.shape)
# Select an initial set of centroids (3 Centroids)
initial_centroids = np.array([[3,3], [6,2], [8,5]])
# Find closest centroids
idx = find_closest_centroids(X, initial_centroids)
# Print closest centroids for the first three elements
print("First three elements in idx are:", idx[:3])
✅ Expected Output
First three elements in idx are: [0 2 1]
💡 Key Insights
- We loop over every point X[i] and compute its distance to every centroid centroids[j].
- We store the distances in a list and then use np.argmin to find the closest centroid.
- The result idx is an array of integers representing the centroid index for each data point.
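The role of np.argmin in the second point can be seen in isolation with a short sketch. The distance values below are invented for illustration. They include a tie to show that np.argmin returns the first position of the minimum, so a point that is equally close to two centroids goes to the one with the lower index.

```python
import numpy as np

# Hypothetical distances from one point to K = 4 centroids
distances = [2.5, 0.7, 3.1, 0.7]

# np.argmin returns the position of the smallest value, not the value itself
closest = np.argmin(distances)
print(closest)  # 1 (positions 1 and 3 tie; the first one wins)
```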
⚡ Bonus: Vectorized Version
Although the loop-based approach is clear, NumPy allows us to write a faster vectorized version that avoids explicit loops:
def find_closest_centroids_vectorized(X, centroids):
    # Compute distance from each point to each centroid
    distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
    # Pick the index of the closest centroid
    idx = np.argmin(distances, axis=1)
    return idx
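The vectorized version relies on NumPy broadcasting, which is easier to follow by looking at the intermediate shapes. The sketch below uses small random arrays (X_demo and centroids_demo are made-up names with made-up data) and compares the result against an explicit loop.

```python
import numpy as np

rng = np.random.default_rng(0)             # fixed seed so the check is repeatable
X_demo = rng.normal(size=(6, 2))           # hypothetical data: m = 6, n = 2
centroids_demo = rng.normal(size=(3, 2))   # hypothetical centroids: K = 3

# Broadcasting: (m, 1, n) - (K, n) -> (m, K, n)
diff = X_demo[:, np.newaxis] - centroids_demo
print(diff.shape)       # (6, 3, 2)

# Taking the norm over the feature axis leaves one distance per (point, centroid) pair
distances = np.linalg.norm(diff, axis=2)
print(distances.shape)  # (6, 3)

# For each row (point), pick the column (centroid) with the smallest distance
idx_vectorized = np.argmin(distances, axis=1)

# Reference result computed with explicit Python loops
idx_loop = np.array([
    np.argmin([np.linalg.norm(x - c) for c in centroids_demo])
    for x in X_demo
])

print(np.array_equal(idx_vectorized, idx_loop))
```

The two index arrays should agree. One trade-off to keep in mind is that the intermediate array has shape (m, K, n), so for very large datasets the vectorized form can use considerably more memory than the loop.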
🎯 Conclusion
Implementing find_closest_centroids is a key building block for understanding
K-Means clustering. Starting with a loop-based approach builds intuition,
while the vectorized version is typically faster on larger datasets because the distance computation runs inside NumPy instead of in Python loops.
Mastering both will strengthen your foundation in machine learning and NumPy programming!