Clustering in Ruby
Clustering algorithms play a very important role in many modern web applications that feature machine learning. This article will introduce you to one of the simplest techniques for the unsupervised grouping of related objects, or clustering.
If you’re at all interested in automated data grouping or sorting, you should be familiar with at least one or two of these algorithms (alongside more advanced machine learning topics, but this serves as a good start). In this article, I’m going to walk through implementing a very simple Ruby program that groups a set of 2D coordinates into clusters, where each cluster consists of a center point and all of the data points closest to it.
The algorithm that we’ll be using to accomplish this task is a simple one; k-means clustering will give us the behavior that we want with little fuss.
Our k-means clustering algorithm takes as input a set of points in the 2-dimensional plane and outputs those points grouped into k clusters, where k is an integer specified by the user. Unfortunately, the algorithm can’t determine the number of groups by itself without added complexity, so k must be given. The algorithm works as follows:
- Start by choosing k random points within the dataset’s range as an initial guess for the cluster positions. These points serve as the initial centroids of the clusters, and all distances to other points are measured from them.
- Assign each point in the input data to the cluster whose centroid is nearest. After this step, each cluster is associated with a set of nearby points.
- For each cluster, compute the mean of its associated data points. This mean becomes the cluster’s new centroid, located at the center of its member points.
- If no centroid moved from its previous location after recentering, or if every centroid moved less than a chosen delta value, return the k clusters and their associated points. Otherwise, clear every cluster’s associated points and go back to the second step. This lets the algorithm start fresh, but with more accurate centroids.
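To make the four steps concrete, here is a minimal sketch that works on plain [x, y] arrays rather than the Point and Cluster classes used later. The helper names kmeans_step and kmeans_arrays, the default delta of 0.001, and the rng keyword are illustrative assumptions. It also seeds the centroids with randomly sampled data points (the approach the kmeans function described later takes) rather than arbitrary points within the dataset’s range.

```ruby
# Steps 2 and 3 on plain [x, y] arrays (hypothetical helper, not the
# article's class-based code). Returns [new_centroids, groups], where
# groups maps a centroid index to its member points.
def kmeans_step(points, centroids)
  groups = Hash.new { |hash, index| hash[index] = [] }

  # Step 2: assign each point to the index of its nearest centroid.
  # Comparing squared distances skips the square root but picks the same winner.
  points.each do |(x, y)|
    nearest = centroids.each_index.min_by do |i|
      cx, cy = centroids[i]
      (x - cx)**2 + (y - cy)**2
    end
    groups[nearest] << [x, y]
  end

  # Step 3: move each centroid to the mean of its members; a centroid
  # with no members stays where it is.
  new_centroids = centroids.each_with_index.map do |(cx, cy), i|
    members = groups[i]
    next [cx, cy] if members.empty?
    [members.sum { |px, _| px } / members.size.to_f,
     members.sum { |_, py| py } / members.size.to_f]
  end

  [new_centroids, groups]
end

# Steps 1-4 together (hypothetical helper). Seeds are k data points drawn
# without replacement; the loop stops once no centroid moves delta or more.
# The default delta and the rng keyword are assumptions.
def kmeans_arrays(points, k, delta = 0.001, rng: Random.new)
  centroids = points.sample(k, random: rng)
  loop do
    new_centroids, groups = kmeans_step(points, centroids)
    shift = centroids.zip(new_centroids).map do |(ax, ay), (bx, by)|
      Math.hypot(ax - bx, ay - by)
    end.max
    centroids = new_centroids
    return [centroids, groups] if shift < delta
  end
end
```

For example, kmeans_step([[0, 0], [0, 2], [10, 10], [10, 12]], [[1, 1], [9, 9]]) moves the two centroids to [0.0, 1.0] and [10.0, 11.0] in a single pass.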
To begin with, we’ll need a class to store the points to be clustered by the algorithm. Essentially, we just need a Point class to hold x and y values. It is simple enough that I won’t insult your intelligence by explaining it.
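For reference, a Point class along these lines might look like the following minimal sketch. The dist_to method name is an assumption, included because the clustering steps need a distance between two points.

```ruby
# A minimal 2D point (sketch). dist_to is a hypothetical method name.
class Point
  attr_accessor :x, :y

  def initialize(x, y)
    @x = x
    @y = y
  end

  # Euclidean distance to another Point.
  def dist_to(other)
    Math.hypot(x - other.x, y - other.y)
  end

  def to_s
    "(#{x}, #{y})"
  end
end
```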
Next, there has to be a class to hold clusters of data. As the algorithm description says, each cluster has a group of member points and a center point (not necessarily in the dataset) associated with it. This corresponds to two instance variables: @points, a list of Points, and @center, a single Point. Additionally, there needs to be a way for Clusters to update their centers by averaging their member points. This is implemented in the recenter! method.
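A minimal sketch of such a Cluster class follows. It assumes recenter! returns the distance the center moved, which is convenient for the delta check in the main loop. A compact Struct stands in for Point so the snippet runs on its own, and its dist_to helper is a hypothetical name.

```ruby
# Compact stand-in for Point so this snippet runs on its own
# (dist_to is a hypothetical method name).
Point = Struct.new(:x, :y) do
  def dist_to(other)
    Math.hypot(x - other.x, y - other.y)
  end
end

# A cluster: a center Point plus the list of member Points assigned to it.
class Cluster
  attr_accessor :center, :points

  def initialize(center)
    @center = center
    @points = []
  end

  # Move the center to the mean of the member points and return how far it
  # moved (an assumed return value, used by the halting check). A cluster
  # with no members keeps its current center and reports 0.0.
  def recenter!
    return 0.0 if points.empty?
    old_center = center
    mean_x = points.sum(&:x) / points.size.to_f
    mean_y = points.sum(&:y) / points.size.to_f
    @center = Point.new(mean_x, mean_y)
    old_center.dist_to(center)
  end
end
```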
Finally, the algorithm itself needs to be implemented.
The parameters to the kmeans function are data, the dataset (a list of Points); k, the number of clusters to find; and delta, an optional halting threshold. The algorithm halts when, on some iteration, every cluster’s center moves by less than delta.
Initially, the algorithm needs to choose starting guesses for the cluster centers. It does this by generating k Cluster objects and giving each one a center taken from a randomly selected Point in data.
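The initialization step might be sketched like this. initial_clusters is a hypothetical helper name, Point and Cluster are Struct stand-ins, and drawing the seeds with Array#sample (without replacement, so no two clusters start on the same element of data) is one reasonable choice rather than necessarily the article’s.

```ruby
# Minimal stand-ins so the snippet runs on its own.
Point = Struct.new(:x, :y)
Cluster = Struct.new(:center, :points)

# Hypothetical helper for the first step of kmeans(data, k, delta): seed k
# clusters with centers taken from k randomly chosen data points. The rng
# keyword is an addition for reproducibility, not part of the article.
def initial_clusters(data, k, rng: Random.new)
  unless (1..data.size).cover?(k)
    raise ArgumentError, "k must be between 1 and #{data.size}"
  end
  data.sample(k, random: rng).map { |point| Cluster.new(point, []) }
end
```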
Next is the main meat of the algorithm. The code loops indefinitely (until the halting check described below breaks out) and assigns each point to the cluster whose center is closest. These assignments are updated, and become more accurate, on each iteration of the loop as the clusters recenter.
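One way to write the assignment step, as a sketch: assign_points is a hypothetical name, and the Struct stand-ins for Point and Cluster (with an assumed dist_to method) keep the snippet self-contained. Clearing each cluster’s members first matches the “start fresh” behavior from the algorithm’s last step.

```ruby
# Minimal stand-ins; dist_to is a hypothetical method name.
Point = Struct.new(:x, :y) do
  def dist_to(other)
    Math.hypot(x - other.x, y - other.y)
  end
end
Cluster = Struct.new(:center, :points)

# Hypothetical helper: drop old memberships, then add each point to the
# cluster whose center is nearest to it.
def assign_points(data, clusters)
  clusters.each { |cluster| cluster.points = [] }
  data.each do |point|
    nearest = clusters.min_by { |cluster| point.dist_to(cluster.center) }
    nearest.points << point
  end
  clusters
end
```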
At the bottom of the while loop, we recalculate the cluster centers for the next iteration by calling recenter! on every Cluster object. We also need to leave the loop eventually, so we do some delta checking: by keeping track of the largest amount any Cluster moved during the update, we can compare it against delta to see whether every Cluster moved less than the input delta. If so, the algorithm terminates and returns a list of all of the Clusters found in the dataset.
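Putting the pieces together, a minimal sketch of the whole kmeans(data, k, delta) loop might look like this. It assumes recenter! returns how far a center moved, uses a default delta of 0.001, and adds an rng keyword for reproducible runs; none of these details come from the original source.

```ruby
# Minimal stand-ins; dist_to is a hypothetical method name.
Point = Struct.new(:x, :y) do
  def dist_to(other)
    Math.hypot(x - other.x, y - other.y)
  end
end

Cluster = Struct.new(:center, :points) do
  # Move the center to the mean of the members; return the distance moved
  # (assumed return value). Empty clusters stay put and report 0.0.
  def recenter!
    return 0.0 if points.empty?
    old = center
    self.center = Point.new(points.sum(&:x) / points.size.to_f,
                            points.sum(&:y) / points.size.to_f)
    old.dist_to(center)
  end
end

# Sketch of the full loop. The default delta and the rng keyword are
# assumptions, not taken from the article.
def kmeans(data, k, delta = 0.001, rng: Random.new)
  # Seed k clusters with randomly selected data points.
  clusters = data.sample(k, random: rng).map { |p| Cluster.new(p, []) }

  loop do
    # Assign every point to its nearest center.
    clusters.each { |c| c.points = [] }
    data.each { |p| clusters.min_by { |c| p.dist_to(c.center) }.points << p }

    # Recenter every cluster, tracking the largest move.
    max_move = clusters.map(&:recenter!).max

    # Halt once no center moved by delta or more.
    return clusters if max_move < delta
  end
end
```

Because the halting test uses the largest move across all clusters, a single slow-converging cluster keeps the loop running, which is the “all clusters below delta” condition described above.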
Overall, k-means clustering is a pretty simple algorithm, as you can see from above. The entire source file, along with glue/integration code, is available here to download and/or view. Next, let’s see the program in action.
As you can see from the link above, I ended up writing some additional glue code to read in the data, get k, and plot the output of the algorithm. For the plotting, I used the gnuplot gem (gem install gnuplot) to pipe commands to a running instance of Gnuplot.
To run the full program, open up a terminal and execute:
$ ruby kmeans.rb CSVFILE
The program assumes that CSVFILE contains two comma-separated floating-point numbers per line, giving the x and y coordinates of a single point.
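Reading that format needs very little code. This sketch, with hypothetical parse_points and load_points helpers and a Struct stand-in for Point, splits each line on the comma and uses Kernel#Float so malformed input raises an error instead of silently becoming 0.0 (as String#to_f would).

```ruby
# Minimal stand-in so the snippet runs on its own.
Point = Struct.new(:x, :y)

# Hypothetical helper: parse text with one "x,y" pair per line into Points.
# Blank lines are skipped; Float() raises ArgumentError on bad numbers.
def parse_points(text)
  text.each_line.filter_map do |line|
    line = line.strip
    next if line.empty?
    fields = line.split(",")
    unless fields.size == 2
      raise ArgumentError, "expected 2 values, got #{fields.size}: #{line}"
    end
    Point.new(*fields.map { |field| Float(field.strip) })
  end
end

# Hypothetical helper: read points from a file such as CSVFILE.
def load_points(path)
  parse_points(File.read(path))
end
```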
To give you a feel for the output that k-means generates, I ran it on a random dataset. In the graph below, points plotted in the same color are the member points of a single Cluster object. To run this yourself, you can grab my dataset here, although creating your own with more distinct, pre-designed clusters may be more interesting to see.
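If you want to build such a dataset, a sketch like the following scatters points uniformly within a square around each chosen center and emits them in the one-“x,y”-pair-per-line format described above. synthetic_csv and its parameters are hypothetical.

```ruby
# Hypothetical helper: for each [cx, cy] center, generate per_cluster points
# uniformly within +/- spread on each axis, returned as "x,y" CSV lines.
def synthetic_csv(centers, per_cluster, spread, rng: Random.new)
  s = spread.to_f
  lines = centers.flat_map do |(cx, cy)|
    Array.new(per_cluster) do
      x = cx + rng.rand(-s..s)
      y = cy + rng.rand(-s..s)
      format("%.4f,%.4f", x, y)
    end
  end
  lines.join("\n") + "\n"
end

# Example: three well-separated groups, written to a file for kmeans.rb.
# File.write("clusters.csv", synthetic_csv([[0, 0], [10, 0], [5, 8]], 50, 1.5))
```

With clusters this well separated, running ruby kmeans.rb clusters.csv with k = 3 should make it easy to judge whether the grouping looks right.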
While this is a very simple example, note that the x and y axes can represent whatever you want (latitude/longitude of households, baseball stats, etc.). You could even extend the program, fairly easily, to support 3 or more parameters (dimensions). Thus, despite its apparent simplicity, k-means clustering can be a powerful tool for grouping real-world datasets.
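As a sketch of that extension, representing each point as an array of coordinates lets the distance and mean computations work in any number of dimensions. The helper names below are hypothetical; swapping them in for the 2D versions covers most of what the core loop would need.

```ruby
# Hypothetical n-dimensional helpers: a point is any array of numbers,
# so the same code handles 2D, 3D, or more features.

# Euclidean distance between two equal-length coordinate arrays.
def distance(a, b)
  raise ArgumentError, "dimension mismatch" unless a.size == b.size
  Math.sqrt(a.zip(b).sum { |ai, bi| (ai - bi)**2 })
end

# Component-wise mean of a non-empty list of equal-length coordinate arrays.
def mean(points)
  points.transpose.map { |column| column.sum / points.size.to_f }
end

# The center from the given list that lies closest to point.
def nearest_center(point, centers)
  centers.min_by { |center| distance(point, center) }
end
```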