k均值算法主要应用在聚类上的无监督算法。
他的目标是将数据集中的样本根据其特征分为几类,使得每一类内部样本的特征都尽可能相近,这样的任务通常被称为聚类任务,作为最简单的聚类算法,k均值算法在现实中有广泛的应用。
原理
假设空间中有一些样本,聚类问题的目标就是将这些样本按距离划分成数个类。设数据集D=\{ x_{1},...,x_{M} \}
,其中每个样本x_{i}\in R^{n}
的特征维数都是n
,最终聚类的个数K由我们提前指定。直观上说,同一类的样本之间距离应该比不同类的样本之间的距离近。但是,由于我们没有任何点的真实标签,所以也无法在最开始确定每一类的中心(centroid),以其为基准计算距离并分类。针对这一问题,k均值算法提出了一个非常简单的解决方案:在初始时,随机选取数据集中的K个样本\mu _{1},...,\mu _{k}
,将\mu _{i}
作为第i
类的中心。选取中心后,我们用最简单的方式,把数据集中的样本归到最近的中心点所代表的类中。记第i
类包含样本的集合为C_{i}
,两样本之间的距离函数为d
,那么C_{i}
可以写为:
C_{i} = \{ x_{j}\in D|\forall l\ne i,d(x_{j},\mu _{i}) \le d(x_{j},\mu _{l})\}
这里的中心点是我们随机选取的,所以还需要优化,尽可能减小类内的样本到中心点的距离。将数据集中所有样本到其对应中心距离之和作为损失函数,得到:
L(C_{1},...,C_{k})=\sum_{i=1}^{k} \sum_{x\in C_{i}}d(x,\mu _{i})=\sum_{i=1}^{K} \sum_{j=1}^{M}Ⅱ(x_{j}\in C_{i})d(x_{j},\mu _{i})
既然在初始时,各个类的中心点\mu _{i}
是随机选取的,那么我们应当再选取新的中心点,使得损失函数的值最小。将上式对\mu _{i}
求偏导,得到:
\frac{\partial L}{\partial \mu_{i}}=\sum_{x\in C_{i}}\frac{\partial d(x,\mu _{i})}{\partial \mu_{i}}
如果我们用欧式距离的平方作为度量标准,即d(x,\mu)=\left | \left | x-\mu \right | \right | ^2
,上式可以进一步计算为:
\frac{\partial L}{\partial \mu_{i}}
=\sum_{x\in C_{i}}\frac{\partial \left | \left | x-\mu \right | \right | ^2}{\partial \mu_{i}}
=2\sum_{x\in C_{i}}(x-\mu_{i})
=2\sum_{x\in C_{i}}x-2\left | C_{i} \right | \mu_{i}
令该偏导数为零,就得到最优的中心点:
\mu_{i}=\frac{1}{\left|C_{i}\right|}\sum_{x\in C_{i}}x
上式表明,最优中心点就是C_{i}
中所有点的质心。但是,当中心点更新后,每个样本到最近的中心点的距离可能也会发生变化。因此,我们重新计算每个样本点到中心点的距离,对它们重新分类,再计算新的质心。如此反复迭代,直到各个点的分类几乎不再变化或者达到预设的迭代次数为止。
注意,如果采用其他的距离函数作为度量,那么最优的中心点就不是集合的质心。
实践
先看一下数据集kmeans_data.csv长啥样:
一共80条数据,每行包含两个值x_{1}
和x_{2}
,表示平面上坐标为(x_{1},x_{2})
的点,考虑我们还希望绘制迭代的中间步骤,这里将绘图部分写成一个函数。
import numpy as np
import matplotlib.pyplot as plt
dataset = np.loadtxt('kmeans_data.csv', delimiter=',')
print('数据集大小:', len(dataset))
数据集大小: 80
# 绘图函数
def show_cluster(dataset, cluster, centroids=None):
# dataset:数据
# centroids:聚类中心点的坐标
# cluster:每个样本所属聚类
# 不同种类的颜色,用以区分划分的数据的类别
colors = ['blue', 'red', 'green', 'purple']
markers = ['o', '^', 's', 'd']
# 画出所有样例
K = len(np.unique(cluster))
for i in range(K):
plt.scatter(dataset[cluster == i, 0], dataset[cluster == i, 1],
color=colors[i], marker=markers[i])
# 画出中心点
if centroids is not None:
plt.scatter(centroids[:, 0], centroids[:, 1],
color=colors[:K], marker='+', s=150)
plt.show()
# 初始时不区分类别
show_cluster(dataset, np.zeros(len(dataset), dtype=int))
对于简单的k均值算法,初始的中心点是从现有样本中随机选取的,如下:
def random_init(dataset, K):
# 随机选取是不重复的
idx = np.random.choice(np.arange(len(dataset)), size=K, replace=False)
return dataset[idx]
下面用欧氏距离作为度量标准,进行中心点的迭代。由于数据集比较简单,我们迭代的终止条件设置为所有样本的分类都不再变化。对于更复杂的数据集,这一条件很可能使迭代无法终止,需要人为控制最大迭代次数,或者设置允许类别变动的样本的比例等。
def Kmeans(dataset, K, init_cent):
# dataset:数据集
# K:目标聚类数
# init_cent:初始化中心点的函数
centroids = init_cent(dataset, K)
cluster = np.zeros(len(dataset), dtype=int)
changed = True
# 开始迭代
itr = 0
while changed:
changed = False
loss = 0
for i, data in enumerate(dataset):
# 寻找最近的中心点
dis = np.sum((centroids - data) ** 2, axis=-1)
k = np.argmin(dis)
# 更新当前样本所属的聚类
if cluster[i] != k:
cluster[i] = k
changed = True
# 计算损失函数
loss += np.sum((data - centroids[k]) ** 2)
# 绘图
print(f'Iteration {itr}, Loss {loss:.3f}')
show_cluster(dataset, cluster, centroids)
# 更新中心点
for i in range(K):
centroids[i] = np.mean(dataset[cluster == i], axis=0)
itr += 1
return centroids, cluster
最后,观察k均值算法在数据集上聚类的过程。根据前面的可视化结果,大概可以看出有4个聚类,因此设定K=4
np.random.seed(0)
cent, cluster = Kmeans(dataset, 4, random_init)
Iteration 0, Loss 711.336
Iteration 1, Loss 409.495
Iteration 2, Loss 395.264
Iteration 3, Loss 346.068
Iteration 4, Loss 294.244
Iteration 5, Loss 178.808
Iteration 6, Loss 151.090
k-means++算法
(待补充)
def kmeanspp_init(dataset, K):
# 随机第一个中心点
idx = np.random.choice(np.arange(len(dataset)))
centroids = dataset[idx][None]
for k in range(1, K):
d = []
# 计算每个点到当前中心点的距离
for data in dataset:
dis = np.sum((centroids - data) ** 2, axis=-1)
# 取最短距离的平方
d.append(np.min(dis) ** 2)
# 归一化
d = np.array(d)
d /= np.sum(d)
# 按概率选取下一个中心点
cent_id = np.random.choice(np.arange(len(dataset)), p=d)
cent = dataset[cent_id]
centroids = np.concatenate([centroids, cent[None]], axis=0)
return centroids
cent, cluster = Kmeans(dataset, 4, kmeanspp_init)
Iteration 0, Loss 373.939
Iteration 1, Loss 158.147
Iteration 2, Loss 151.273