
Unsupervised Learning and K-means

Up to this point, we’ve seen machine learning where the training data contained information about the right answer. Perceptrons, regression, and neural networks are examples of supervised learning, where each input is paired with the correct output. Reinforcement learning sits somewhere in between: while we aren’t told exactly which action is correct, we are given information about how good our past actions were.

There is a third kind of ML known as unsupervised learning. In this setting, we are given only the inputs and must make sense of them on our own. Obviously, this limits what we can do: without a target $y$, we cannot learn a regression function or a policy. Usually, unsupervised learning amounts to finding relationships among the individual data points, typically in the form of clustering. In clustering, we divide the data into groups so that when a new data point comes along, we can categorize it as belonging to one of those clusters.

K-means

One of the simplest and most common clustering algorithms is known as “k-means.” The basic idea is that we maintain a set of $k$ points in the input space, which we call the “means.” Each input is then assigned to the cluster belonging to its closest mean.

Because our initial means are likely poor, we fix them by moving each mean to the actual mean of the points currently assigned to that cluster. Of course, this movement might change which points are closest to which means, so we re-classify everything. We repeat this process of classifying and adjusting the means until we reach a point where no mean moves at all.
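A tiny worked example may make the loop concrete. Assume (our own toy data, not from the source) the one-dimensional inputs $1, 2, 10, 12$, $k = 2$, and initial means $\mu = (\mu_1, \mu_2) = (1, 2)$ picked from the inputs. Each line assigns every input to its nearest mean and then replaces each mean by the average of its cluster:

$$
\begin{aligned}
\text{iteration 1:}\quad & \mu = (1, 2) \;\Rightarrow\; \{1\},\ \{2, 10, 12\} \;\Rightarrow\; \mu = \left(1,\ \tfrac{2 + 10 + 12}{3}\right) = (1, 8) \\
\text{iteration 2:}\quad & \mu = (1, 8) \;\Rightarrow\; \{1, 2\},\ \{10, 12\} \;\Rightarrow\; \mu = \left(\tfrac{1 + 2}{2},\ \tfrac{10 + 12}{2}\right) = (1.5, 11) \\
\text{iteration 3:}\quad & \mu = (1.5, 11) \;\Rightarrow\; \{1, 2\},\ \{10, 12\} \;\Rightarrow\; \mu = (1.5, 11)
\end{aligned}
$$

The input 2 changes clusters once the second mean jumps from 2 to 8; in iteration 3 no mean moves, so the algorithm stops with means 1.5 and 11.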

While the definition of “closest” is flexible, we typically use the squared Euclidean distance, assigning each input $x_i$ to the cluster of its nearest mean:

$$c_i = \arg\min_j \|x_i - \mu_j\|^2$$

where $c_i$ is the index of the cluster assigned to input $x_i$ and $\mu_1, \dots, \mu_k$ are the current means. We then update each mean using the following formula:

$$\mu_j = \frac{1}{|C_j|}\sum_{i \in C_j} x_i, \qquad C_j = \{\, i : c_i = j \,\}$$

This is simply the average (centroid) of the points currently assigned to cluster $j$. We repeat the assignment and update steps until the means stop moving.
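Both formulas translate almost line for line into code. The following is a minimal pure-Python sketch (our own illustration, not code from the source; the names `squared_distance`, `assign`, `update`, and `kmeans` are hypothetical). It assumes points are tuples of floats and that the caller supplies the initial means. The text does not say what to do when a cluster ends up empty, so here such a cluster simply keeps its previous mean.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """||a - b||^2, the squared Euclidean distance."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def assign(points: List[Point], means: List[Point]) -> List[int]:
    """Assignment step: c_i = argmin_j ||x_i - mu_j||^2 (ties go to the lowest j)."""
    return [
        min(range(len(means)), key=lambda j: squared_distance(x, means[j]))
        for x in points
    ]


def update(points: List[Point], labels: List[int], means: List[Point]) -> List[Point]:
    """Update step: mu_j = average of the points with c_i = j.

    An empty cluster keeps its previous mean (an assumption; the text
    does not cover this case).
    """
    new_means = []
    for j, old_mean in enumerate(means):
        members = [x for x, c in zip(points, labels) if c == j]
        if members:
            # zip(*members) walks the coordinates column by column.
            new_means.append(tuple(sum(col) / len(members) for col in zip(*members)))
        else:
            new_means.append(old_mean)
    return new_means


def kmeans(points: List[Point], initial_means: List[Point], max_iter: int = 100):
    """Alternate assignment and update until no mean moves (or max_iter is hit)."""
    means = [tuple(m) for m in initial_means]
    labels = assign(points, means)
    for _ in range(max_iter):
        new_means = update(points, labels, means)
        if new_means == means:  # no mean moved, so the assignments are stable too
            break
        means = new_means
        labels = assign(points, means)
    return means, labels


if __name__ == "__main__":
    data = [(1.0,), (2.0,), (10.0,), (12.0,)]
    print(kmeans(data, [(1.0,), (2.0,)]))  # ([(1.5,), (11.0,)], [0, 0, 1, 1])
```

The exact equality check on the means is safe here: once the assignments stop changing, the update step recomputes exactly the same averages, so the loop ends after finitely many passes.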

Initialization and $k$

How do we initialize the means? There are a few common ways:

  1. Pick random points in the space: This tends to work less well because the points could initially be very far from the actual data.

  2. Pick random samples of the input: This works better because we are likely to get initial means near where they should eventually be.

  3. Random assignment: Assign each input to one of the $k$ clusters at random. The initial means are then the centroids of these random clusters. This tends to place all the initial means near the center of the data.
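The three strategies can be sketched in pure Python as follows. These are hypothetical helpers for illustration: the function names, the `random.Random` argument, reading “points in the space” as the data’s bounding box, and the fallback for a label that no input receives are all our assumptions. Points are stored as tuples of floats.

```python
import random
from typing import List, Tuple

Point = Tuple[float, ...]


def centroid(points: List[Point]) -> Point:
    """Coordinate-wise average of a non-empty list of points."""
    return tuple(sum(col) / len(points) for col in zip(*points))


def init_random_points(points: List[Point], k: int, rng: random.Random) -> List[Point]:
    """Method 1: uniform random points in the data's bounding box.

    Treating the bounding box as "the space" is an assumption.
    """
    lows = [min(col) for col in zip(*points)]
    highs = [max(col) for col in zip(*points)]
    return [tuple(rng.uniform(lo, hi) for lo, hi in zip(lows, highs)) for _ in range(k)]


def init_random_samples(points: List[Point], k: int, rng: random.Random) -> List[Point]:
    """Method 2: k inputs drawn at random without replacement."""
    return rng.sample(points, k)


def init_random_assignment(points: List[Point], k: int, rng: random.Random) -> List[Point]:
    """Method 3: give every input a random cluster label, then take each cluster's centroid."""
    labels = [rng.randrange(k) for _ in points]
    means = []
    for j in range(k):
        members = [x for x, c in zip(points, labels) if c == j]
        # With few inputs a label can go unused; falling back to a random input is an assumption.
        means.append(centroid(members) if members else rng.choice(points))
    return means


if __name__ == "__main__":
    rng = random.Random(0)
    # Three well-separated synthetic blobs (made-up data for illustration).
    centers = [(0.0, 0.0), (20.0, 0.0), (10.0, 20.0)]
    data = [(cx + rng.gauss(0, 1), cy + rng.gauss(0, 1)) for cx, cy in centers for _ in range(50)]
    print("overall centroid: ", centroid(data))
    print("random assignment:", init_random_assignment(data, 3, rng))
```

Printing the overall centroid next to the random-assignment means should typically show them landing close together: each random cluster is a roughly unbiased sample of all the data, so its average sits near the average of everything.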

Finally, we must address how to select the number $k$. This is arguably the biggest weakness of k-means: you must choose $k$ before you start, and that value has a large effect on the quality of the output. As with other hyperparameters, we can try a variety of values and see which works best.
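One common way to make “see which works best” concrete (a heuristic, not prescribed by the text) is to run k-means for several values of $k$ and record the within-cluster sum of squared distances (WCSS), that is, the sum over all inputs of $\|x_i - \mu_{c_i}\|^2$. Because this quantity tends to shrink as $k$ grows (it reaches zero when every distinct point is its own cluster), we look for the “elbow” where adding another cluster stops helping much, rather than for the minimum. Since k-means can get stuck in a poor solution depending on its starting means, the sketch keeps the best of several random restarts per $k$. The helper names, the restart count, and the synthetic three-blob data are illustrative assumptions.

```python
import random
from typing import Dict, Iterable, List, Tuple

Point = Tuple[float, ...]


def sq_dist(a: Point, b: Point) -> float:
    """Squared Euclidean distance."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def kmeans(points: List[Point], k: int, rng: random.Random, max_iter: int = 100) -> List[Point]:
    """Plain k-means started from k randomly sampled inputs; returns the final means."""
    means = rng.sample(points, k)
    for _ in range(max_iter):
        labels = [min(range(k), key=lambda j: sq_dist(x, means[j])) for x in points]
        new_means = []
        for j in range(k):
            members = [x for x, c in zip(points, labels) if c == j]
            # An empty cluster keeps its old mean (an assumption).
            new_means.append(
                tuple(sum(col) / len(members) for col in zip(*members)) if members else means[j]
            )
        if new_means == means:  # no mean moved
            break
        means = new_means
    return means


def wcss(points: List[Point], means: List[Point]) -> float:
    """Within-cluster sum of squares: each point's squared distance to its nearest mean."""
    return sum(min(sq_dist(x, m) for m in means) for x in points)


def wcss_by_k(points: List[Point], ks: Iterable[int], restarts: int = 20, seed: int = 0) -> Dict[int, float]:
    """Lowest WCSS over several random restarts, for each candidate k."""
    rng = random.Random(seed)
    return {k: min(wcss(points, kmeans(points, k, rng)) for _ in range(restarts)) for k in ks}


if __name__ == "__main__":
    rng = random.Random(0)
    centers = [(0.0, 0.0), (20.0, 0.0), (10.0, 20.0)]  # made-up data with 3 true groups
    data = [(cx + rng.gauss(0, 1), cy + rng.gauss(0, 1)) for cx, cy in centers for _ in range(50)]
    for k, score in wcss_by_k(data, range(1, 6)).items():
        print(k, round(score, 1))
```

On data like this, the printed scores should drop sharply up to $k = 3$ and only slightly afterward, pointing to three clusters. Real data rarely shows such a clean elbow, so the choice of $k$ often remains a judgment call.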

Example Walkthrough

Figure: K-means Example Steps. (Image from James, Witten, Hastie, and Tibshirani. An Introduction to Statistical Learning.)

In this figure, the upper left panel shows the raw input data. This example initializes the means using the random assignment method, so every point is first assigned to a random cluster (upper middle). As a result, all the initial means lie near the center of the data (upper right). We then reassign every point to its closest mean (lower left) and move each mean to the centroid of its new cluster (lower middle), repeating these two steps until the means stop moving.

K-means is a straightforward algorithm that is easy to implement. If you need quick, reasonable clusters for your data, try k-means first!