$k$-means is one of the simplest algorithms for finding clusters in a dataset. “Cluster” is not a particularly well-defined concept, but the general idea is that some elements of a dataset are more similar to each other than they are to other elements, and so they form a cluster. Once we know which points belong to which clusters, we can try to understand the dataset by understanding its clusters, which can be a significant reduction in complexity.

Optimization formulation for $k$-means

Although $k$-means is usually described operationally (“find cluster centers, find assignments, repeat”), there’s a cleaner way to describe it in terms of an optimization criterion.

Means are minimizers

The mean of a set of vectors is another object that is often described operationally (“add the vectors, divide by the count”), and less often as the best candidate under some criterion. It turns out that the mean is the minimizer of a very natural function: the sum of squared distances. In other words, the mean of a set of vectors is the vector that minimizes the sum of squared distances from itself to the vectors under consideration. Given a set of $n$ vectors $\{ v_1, \ldots, v_n \}$, we define the error function to minimize as

\[E(c) = (1/2) \sum_i ||c - v_i||^2\]

Taking the gradient with respect to $c$ directly gives the answer:

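\[\nabla E(c) = \sum_i (c - v_i) = n c - \sum_i v_i\]

Setting the gradient to zero and solving for $c$ gives the minimizer (the stationary point is a minimum because $E$ is convex in $c$):

\[\hat{c} = \frac{\sum_i v_i}{n}\]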

We use the hat notation $\hat{c}$ in analogy to linear least squares estimators like $\hat{\beta}$ to highlight the fact that the mean of a set of points can be seen as the linear least squares estimate for this set of points under a model consisting only of constant functions.
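The claim that the mean minimizes the sum of squared distances can be checked numerically. The following is a minimal sketch in plain Python; the helper names `mean` and `half_sum_sq_dist` and the three sample vectors are illustrative assumptions, not part of the derivation above.

```python
def mean(vectors):
    """Component-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(component) / n for component in zip(*vectors)]


def half_sum_sq_dist(c, vectors):
    """E(c) = (1/2) * sum_i ||c - v_i||^2."""
    return 0.5 * sum(
        sum((cj - vj) ** 2 for cj, vj in zip(c, v)) for v in vectors
    )


# Made-up sample data for illustration.
vectors = [[0.0, 0.0], [2.0, 0.0], [1.0, 3.0]]
c_hat = mean(vectors)
base = half_sum_sq_dist(c_hat, vectors)

# Nudging c_hat in a few directions should always increase the error.
nudges = [(0.1, 0.0), (-0.1, 0.0), (0.0, 0.1), (0.0, -0.1), (0.3, -0.2)]
is_minimum = all(
    half_sum_sq_dist([c_hat[0] + dx, c_hat[1] + dy], vectors) > base
    for dx, dy in nudges
)
print(c_hat, base, is_minimum)
```

Checking a handful of nudges is of course not a proof; it only illustrates what the gradient argument establishes.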


The optimal solution for $k$-means finds the best set of $k$ “cluster centers” and the best assignment of input points to cluster centers, where “best” is defined as minimizing the sum of squared distances from each center to the vectors assigned to it.

Let’s first set up some notation. We let the variable $j$ range over cluster indices, from $1$ to $k$. The variable $i$ will range over data points, from $1$ to $n$. We will use $a_i$ to mean the assignment of point $i$, that is, the index of the cluster it belongs to. If $k = 3$, then $a_i \in \{1, 2, 3\}$. In addition, we will use $c_j$ to mean the center of cluster $j$. A potential solution of the $k$-means problem is then some assignment $a$ and some cluster centers $c$. The error of any given solution is the sum of the squared distances from each point to the center of the cluster it is assigned to:

\[E(c, a) = \sum_{i=1}^n || v_i - c_{a_i} || ^2\]
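To make the error concrete, here is a minimal sketch that evaluates it for a given set of centers and assignments. The function names and the toy data are assumptions made for illustration, and the code uses 0-based indices where the text uses 1-based ones.

```python
def sq_dist(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def kmeans_error(points, centers, assignment):
    """Sum over points of the squared distance to the assigned center."""
    return sum(sq_dist(v, centers[a]) for v, a in zip(points, assignment))


# Made-up data: two obvious groups, around x = 0 and around x = 10.
points = [[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]]
centers = [[0.0, 1.0], [10.0, 1.0]]

good = [0, 0, 1, 1]  # every point assigned to its nearby center
bad = [1, 0, 1, 1]   # first point assigned to the far center

print(kmeans_error(points, centers, good))  # 4.0
print(kmeans_error(points, centers, bad))   # 104.0
```

A single misassigned point dominates the error here because distances are squared.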

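Unlike the single-mean case, this formulation does not have a closed-form solution: the centers depend on the assignments and the assignments depend on the centers. We therefore need an actual algorithm to search for a good solution.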

Alternating Optimization

$k$-means is the quintessential “alternating optimization” algorithm: if a formulation is hard to solve in its entirety, it’s often easier to solve it in steps. In the case of $k$-means, if we have a guess for the centers, then finding the best assignment is easy: we iterate over all pairs of data points and centers, and assign each point to the center closest to it. And if we have a guess for the assignments, finding the best centers for those assignments is also easy: each center is just the mean of the points assigned to it.
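The two alternating steps can be sketched as follows, in plain Python with no external libraries. The function names, the stopping rule (stop when assignments no longer change), the `max_iters` cap, and the handling of empty clusters (leave the center where it is) are assumptions of this sketch rather than something the description above prescribes.

```python
def sq_dist(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def mean(vectors):
    """Component-wise mean of a non-empty list of vectors."""
    n = len(vectors)
    return [sum(comp) / n for comp in zip(*vectors)]


def assign_points(points, centers):
    """Assignment step: index of the nearest center for each point.
    Ties go to the lowest center index."""
    return [
        min(range(len(centers)), key=lambda j: sq_dist(p, centers[j]))
        for p in points
    ]


def update_centers(points, assignment, centers):
    """Update step: each center becomes the mean of its assigned points.
    Assumption: a center with no assigned points stays where it is."""
    new_centers = []
    for j, old in enumerate(centers):
        members = [p for p, a in zip(points, assignment) if a == j]
        new_centers.append(mean(members) if members else list(old))
    return new_centers


def kmeans(points, initial_centers, max_iters=100):
    """Alternate the two steps until the assignments stop changing."""
    centers = [list(c) for c in initial_centers]
    assignment = None
    for _ in range(max_iters):
        new_assignment = assign_points(points, centers)
        if new_assignment == assignment:
            break  # nothing changed, so the centers would not move either
        assignment = new_assignment
        centers = update_centers(points, assignment, centers)
    return centers, assignment


# Made-up data with deliberately poor starting centers.
points = [[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]]
centers, assignment = kmeans(points, initial_centers=[[1.0, 0.0], [2.0, 0.0]])
print(centers, assignment)
```

On this toy input the loop settles after one update, with the centers at the means of the left and right pairs of points.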

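But does this algorithm terminate? And does it give an optimal solution? Neither step can increase the error, and there are only finitely many possible assignments, so (with ties broken consistently) the algorithm stops after a finite number of iterations. The solution it stops at is only a local optimum, however: it depends on the initial guess and is not guaranteed to be the best possible clustering.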



Additional reading

The choice of initialization for $k$-means can greatly affect how fast it converges (and how good the results are). $k$-means++ offers a simple rule for initialization that has provable approximation guarantees.
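As a rough illustration, the seeding rule usually described as $k$-means++ picks the first center uniformly at random from the data, and each later center with probability proportional to its squared distance from the nearest center chosen so far. The sketch below assumes that description; the function name `kmeans_pp_init`, the optional `rng` parameter, and the fallback for the degenerate case where all remaining weights are zero are choices made here for illustration.

```python
import random


def sq_dist(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def kmeans_pp_init(points, k, rng=None):
    """Choose k initial centers from `points` with distance-squared weighting."""
    rng = rng or random.Random()
    centers = [list(rng.choice(points))]
    while len(centers) < k:
        # Weight of a point: squared distance to its nearest chosen center.
        weights = [min(sq_dist(p, c) for c in centers) for p in points]
        if sum(weights) == 0:
            # Assumption: every point coincides with a chosen center,
            # so fall back to a uniform pick.
            centers.append(list(rng.choice(points)))
        else:
            centers.append(list(rng.choices(points, weights=weights, k=1)[0]))
    return centers


# Made-up data: two groups of identical points.
points = [[0.0, 0.0], [0.0, 0.0], [10.0, 0.0], [10.0, 0.0]]
initial = kmeans_pp_init(points, k=2, rng=random.Random(0))
print(sorted(initial))
```

Because points that coincide with an already chosen center get weight zero, the two seeds in this example always land in different groups, whichever point is drawn first.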