import matplotlib.pyplot as plt
import seaborn as sns; sns.set() # for plot styling
import numpy as np
K-Means Example
This notebook is based on the Python Data Science Handbook by Jake VanderPlas; the content is available on GitHub. The text is released under the CC-BY-NC-ND license, and code is released under the MIT license.
k-Means Clustering
from sklearn.datasets import make_blobs
X, y_true = make_blobs(n_samples=300, centers=4,
                       cluster_std=0.60, random_state=0)
plt.scatter(X[:, 0], X[:, 1], s=50);
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=4)
kmeans.fit(X)
y_kmeans = kmeans.predict(X)

plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='viridis')

centers = kmeans.cluster_centers_
plt.scatter(centers[:, 0], centers[:, 1], c='black', s=200, alpha=0.5);
k-Means Algorithm: Expectation–Maximization
from sklearn.metrics import pairwise_distances_argmin
def find_clusters(X, n_clusters, rseed=2):
    # 1. Randomly choose clusters
    rng = np.random.RandomState(rseed)
    i = rng.permutation(X.shape[0])[:n_clusters]
    centers = X[i]

    while True:
        # 2a. Assign labels based on closest center
        labels = pairwise_distances_argmin(X, centers)

        # 2b. Find new centers from means of points
        new_centers = np.array([X[labels == i].mean(0)
                                for i in range(n_clusters)])

        # 2c. Check for convergence
        if np.all(centers == new_centers):
            break
        centers = new_centers

    return centers, labels

centers, labels = find_clusters(X, 4)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');
Caveats of expectation–maximization
There are a few issues to be aware of when using the expectation–maximization algorithm.
The globally optimal result may not be achieved
First, although the E–M procedure is guaranteed to improve the result in each step, there is no assurance that it will lead to the global best solution. For example, if we use a different random seed in our simple procedure, the particular starting guesses lead to poor results:
centers, labels = find_clusters(X, 4, rseed=0)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');
Here the E–M approach has converged, but has not converged to a globally optimal configuration. For this reason, it is common for the algorithm to be run for multiple starting guesses, as indeed Scikit-Learn does by default (set by the n_init
parameter, which defaults to 10).
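As a minimal sketch (not from the original text), the multiple-restart behavior can be made explicit by passing n_init yourself; the variable names here are purely illustrative, and the value 10 simply mirrors the default described above:

# Run k-means from 10 different random initializations and keep the
# solution with the lowest inertia (sum of squared distances to centers).
kmeans_multi = KMeans(n_clusters=4, n_init=10, random_state=0)
labels_multi = kmeans_multi.fit_predict(X)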
The number of clusters must be selected beforehand
Another common challenge with k-means is that you must tell it how many clusters you expect: it cannot learn the number of clusters from the data. For example, if we ask the algorithm to identify six clusters, it will happily proceed and find the best six clusters:
labels = KMeans(6, random_state=0).fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');
Whether the result is meaningful is a question that is difficult to answer definitively; one approach that is rather intuitive, but that we won’t discuss further here, is called silhouette analysis.
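Although we won't cover silhouette analysis in detail, a minimal sketch using Scikit-Learn's silhouette_score gives the flavor: higher average scores suggest better-separated clusters. The range of candidate cluster counts below is an arbitrary choice for illustration:

from sklearn.metrics import silhouette_score

# Compare the average silhouette score for a few candidate cluster counts;
# a higher score suggests a better-separated clustering (a rough guide only).
for k in range(2, 8):
    candidate_labels = KMeans(n_clusters=k, random_state=0).fit_predict(X)
    print(k, silhouette_score(X, candidate_labels))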
Alternatively, you might use a more complicated clustering algorithm which has a better quantitative measure of the fitness per number of clusters (e.g., Gaussian mixture models; see In Depth: Gaussian Mixture Models) or which can choose a suitable number of clusters (e.g., DBSCAN, mean-shift, or affinity propagation, all available in the sklearn.cluster
submodule).
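For instance, here is a minimal sketch of the DBSCAN approach mentioned above, which infers the number of clusters from the density of the data rather than taking it as an input; the eps and min_samples values are assumptions, not tuned for this dataset:

from sklearn.cluster import DBSCAN

# DBSCAN groups densely packed points and does not need n_clusters up front;
# points it cannot assign to any dense region are labeled -1 (noise).
db_labels = DBSCAN(eps=0.5, min_samples=5).fit_predict(X)
print(np.unique(db_labels))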
k-means is limited to linear cluster boundaries
The fundamental model assumption of k-means (points will be closer to their own cluster center than to others) means that the algorithm will often be ineffective if the clusters have complicated geometries.
In particular, the boundaries between k-means clusters will always be linear, which means that it will fail for more complicated boundaries. Consider the following data, along with the cluster labels found by the typical k-means approach:
from sklearn.datasets import make_moons
X, y = make_moons(200, noise=.05, random_state=0)

labels = KMeans(2, random_state=0).fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');