内容目录
k近邻算法主要应用于在分类上的有监督算法中。
原理
在分类任务中,我们的目标是判断样本x的类别y。
KNN会先观察与该样本点距离最近的K个样本,统计这些样本所属的类别。然后,将当前样本归到出现次数最多的类中。我们用KNN算法的经典示意图来更清晰地说明其思想
如下图:
假设共有两个类别的数据点,圆形和正方形,而中心位置的样本×当前尚未被分类,根据统计近邻的思路:
- 当K=3时,样本×的3个邻居中有2个正方形样本,1个圆形样本,因此应该将样本×归类为正方形。
- 当K=5时,样本×的5个邻居中有2个正方形样本,3个圆形样本,因此应该将样本×归类为圆形。
从这个例子可以看出,KNN的基本思路是让当前的样本的分类服从邻居中的多数分类。但是,当K的大小变化时,由于邻居的数量变化,其多数类别也可能会变化,从而改变对当前样本的分类判断。因此,决定K的大小是KNN中最重要的部分之一。直观上来说,当K的取值太小时,分类结果很容易受到待分类样本周围的个别噪声数据影响;当K的取值太大时,又可能将远处一些不相关的样本包含进来。因此,我们应该根据数据集动态地调整K的大小,以得到最理想的结果。
使用KNN分类
MNIST是手写数字数据集(mnist官网),其中包含了很多手写数字0~9的黑白图像,每幅图像的尺寸是28×28像素。每个像素用1或0表示,1代表黑色像素,属于图像背景;0代表白色像素,属于手写数字。
我们的任务是用KNN对不同的手写数字进行分类。为了更清晰地展示数据集的内容,下面先将一个数据点转化成黑白图像显示出来。此外,把每个数据集都按8:2的比例随机划分成训练集(training set)和测试集(test set),我们先在训练集上应用KNN算法,再在测试集上测试算法的表现。
import matplotlib.pyplot as plt
import numpy as np
import os
# 读入mnist数据集
m_x = np.loadtxt('mnist_x', delimiter=' ')
m_y = np.loadtxt('mnist_y')
# 数据集可视化
data = np.reshape(np.array(m_x[0], dtype=int), [28, 28])
plt.figure()
plt.imshow(data, cmap='gray')
# 将数据集分为训练集和测试集
ratio = 0.8
split = int(len(m_x) * ratio)
# 打乱数据,用np.random.permutation
np.random.seed(0)
idx = np.random.permutation(np.arange(len(m_x)))
m_x = m_x[idx]
m_y = m_y[idx]
x_train, x_test = m_x[:split], m_x[split:]
y_train, y_test = m_y[:split], m_y[split:]
这里采用最常用的欧氏距离
def distance(a, b):
return np.sqrt(np.sum(np.square(a - b)))
class KNN:
def __init__(self, k, label_num):
self.k = k
self.label_num = label_num # 类别的数量
def fit(self, x_train, y_train):
# 在类中保存训练数据
self.x_train = x_train
self.y_train = y_train
def get_knn_indices(self, x):
# 获取距离目标样本点最近的K个样本点的标签
# 计算已知样本的距离
dis = list(map(lambda a: distance(a, x), self.x_train))
# 按距离从小到大排序,并得到对应的下标
knn_indices = np.argsort(dis)
# 取最近的K个
knn_indices = knn_indices[:self.k]
return knn_indices
def get_label(self, x):
# 对KNN方法的具体实现,观察K个近邻并使用np.argmax获取其中数量最多的类别
knn_indices = self.get_knn_indices(x)
# 类别计数
label_statistic = np.zeros(shape=[self.label_num])
for index in knn_indices:
label = int(self.y_train[index])
label_statistic[label] += 1
# 返回数量最多的类别
return np.argmax(label_statistic)
def predict(self, x_test):
# 预测样本 test_x 的类别
predicted_test_labels = np.zeros(shape=[len(x_test)], dtype=int)
for i, x in enumerate(x_test):
predicted_test_labels[i] = self.get_label(x)
return predicted_test_labels
最后,在测试集上观察算法的效果,并对不同的K的取值进行测试。
for k in range(1, 10):
knn = KNN(k, label_num=10)
knn.fit(x_train, y_train)
predicted_labels = knn.predict(x_test)
accuracy = np.mean(predicted_labels == y_test)
print(f'K的取值为 {k}, 预测准确率为 {accuracy * 100:.1f}%')
K的取值为 1, 预测准确率为 88.5%
K的取值为 2, 预测准确率为 88.0%
K的取值为 3, 预测准确率为 87.5%
K的取值为 4, 预测准确率为 87.5%
K的取值为 5, 预测准确率为 88.5%
K的取值为 6, 预测准确率为 88.5%
K的取值为 7, 预测准确率为 88.0%
K的取值为 8, 预测准确率为 87.0%
K的取值为 9, 预测准确率为 87.0%
使用scikit-learn实现KNN算法
from sklearn.neighbors import KNeighborsClassifier # sklearn中的KNN分类器
from matplotlib.colors import ListedColormap
# 读入高斯数据集
data = np.loadtxt('gauss.csv', delimiter=',')
x_train = data[:, :2]
y_train = data[:, 2]
print('数据集大小:', len(x_train))
# 可视化
plt.figure()
plt.scatter(x_train[y_train == 0, 0], x_train[y_train == 0, 1], c='blue', marker='o')
plt.scatter(x_train[y_train == 1, 0], x_train[y_train == 1, 1], c='red', marker='x')
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.show()
数据集大小: 200
# 设置步长
step = 0.02
# 设置网格边界
x_min, x_max = np.min(x_train[:, 0]) - 1, np.max(x_train[:, 0]) + 1
y_min, y_max = np.min(x_train[:, 1]) - 1, np.max(x_train[:, 1]) + 1
# 构造网格
xx, yy = np.meshgrid(np.arange(x_min, x_max, step), np.arange(y_min, y_max, step))
grid_data = np.concatenate([xx.reshape(-1, 1), yy.reshape(-1, 1)], axis=1)
fig = plt.figure(figsize=(16,4.5))
# K值,读者可以自行调整,观察分类结果的变化
ks = [1, 3, 10]
cmap_light = ListedColormap(['royalblue', 'lightcoral'])
for i, k in enumerate(ks):
# 定义KNN分类器
knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(x_train, y_train)
z = knn.predict(grid_data)
# 画出分类结果
ax = fig.add_subplot(1, 3, i + 1)
ax.pcolormesh(xx, yy, z.reshape(xx.shape), cmap=cmap_light, alpha=0.7)
ax.scatter(x_train[y_train == 0, 0], x_train[y_train == 0, 1], c='blue', marker='o')
ax.scatter(x_train[y_train == 1, 0], x_train[y_train == 1, 1], c='red', marker='x')
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_title(f'K = {k}')
plt.show()