KNN算法实战:从距离度量到鸢尾花分类

KNN算法实战:从距离度量到鸢尾花分类 1. KNN算法基础与距离度量K最近邻K-Nearest Neighbors简称KNN是机器学习中最简单直观的分类算法之一。它的核心思想可以用一句老话来概括近朱者赤近墨者黑。算法通过计算待分类样本与训练集中各个样本的距离找出距离最近的K个邻居然后根据这些邻居的类别投票决定待分类样本的类别。1.1 常见距离公式详解在实际应用中我们主要使用以下几种距离度量方式欧氏距离Euclidean Distance这是最直观的距离度量方式就是我们日常生活中理解的两点之间的直线距离。数学表达式为$$ d_{12} \sqrt{\sum_{k1}^{n} \left( X_{2k} - X_{1k} \right)^2} $$适用场景当特征之间是连续数值型且量纲相同时欧氏距离表现最好。比如在鸢尾花分类中花瓣长度和宽度的测量单位都是厘米就非常适合使用欧氏距离。曼哈顿距离Manhattan Distance又称城市街区距离得名于在曼哈顿街区行走的路径。数学表达式为$$ d_{12} \sum_{k1}^{n} \left| X_{2k} - X_{1k} \right| $$特点与适用场景只能沿着坐标轴方向移动不能斜穿对异常值比欧氏距离更鲁棒适用于高维数据或特征之间存在相关性时切比雪夫距离Chebyshev Distance定义为各坐标数值差的最大值$$ d_{12} \max_k \left| X_{2k} - X_{1k} \right| $$适用场景在棋盘游戏中特别有用比如计算国王从一个格子走到另一个格子的最少步数。闵可夫斯基距离Minkowski Distance这是上述距离的一般化形式$$ d_{12} \left( \sum_{k1}^{n} \left| X_{2k} - X_{1k} \right|^p \right)^{1/p} $$当p1时就是曼哈顿距离p2时就是欧氏距离p→∞时趋近于切比雪夫距离。距离选择小贴士在实际应用中欧氏距离是最常用的默认选择。但如果特征量纲差异大应先进行标准化处理。对于高维数据可以尝试曼哈顿距离或余弦相似度。1.2 距离度量的代码实现import numpy as np # 欧氏距离 def euclidean_distance(a, b): return np.sqrt(np.sum((a - b)**2)) # 曼哈顿距离 def manhattan_distance(a, b): return np.sum(np.abs(a - b)) # 切比雪夫距离 def chebyshev_distance(a, b): return np.max(np.abs(a - b)) # 测试 a np.array([1, 2, 3]) b np.array([4, 5, 6]) print(f欧氏距离: {euclidean_distance(a, b):.2f}) print(f曼哈顿距离: {manhattan_distance(a, b)}) print(f切比雪夫距离: {chebyshev_distance(a, b)})2. 数据预处理与特征工程2.1 归一化处理归一化Normalization是将数据按比例缩放到一个特定区间通常是[0,1]的过程。计算公式为$$ X \frac{x - \min}{\max - \min} $$特点对最大值和最小值敏感容易受异常值影响适用于数据分布有明显边界的情况常用于小数据集或需要保持数据原始分布的场景实战代码from sklearn.preprocessing import MinMaxScaler import numpy as np # 创建归一化对象可指定范围默认[0,1] scaler MinMaxScaler(feature_range(0, 1)) # 示例数据 data np.array([[90, 2, 10, 40], [60, 4, 15, 45], [75, 3, 13, 46]]) # 拟合并转换数据 normalized_data scaler.fit_transform(data) print(归一化后的数据) print(normalized_data)2.2 标准化处理标准化Standardization是将数据转换为均值为0标准差为1的分布。计算公式为$$ X \frac{x - \mu}{\sigma} $$其中μ是均值σ是标准差。特点对异常值不敏感适用于数据分布没有明显边界或存在异常值的情况更适合大数据集是机器学习中最常用的预处理方法实战代码from sklearn.preprocessing import StandardScaler # 创建标准化对象 scaler StandardScaler() # 示例数据 data [[90, 2, 10, 40], [60, 4, 15, 45], [75, 3, 13, 46]] # 拟合并转换数据 standardized_data scaler.fit_transform(data) print(标准化后的数据) print(standardized_data) # 查看转换参数 print(f均值: {scaler.mean_}) print(f方差: {scaler.var_})2.3 归一化与标准化的选择指南特性归一化标准化计算方法(x-min)/(max-min)(x-μ)/σ结果范围[0,1]或自定义理论上无界实际多在[-3,3]对异常值敏感不敏感适用场景小数据集有明显边界大数据集可能有异常值算法适用性图像处理神经网络大多数机器学习算法经验法则当不确定用哪种时优先选择标准化。特别是使用距离度量的算法如KNN、K-Means时标准化通常能获得更好的效果。3. 鸢尾花分类实战3.1 数据集探索鸢尾花数据集是机器学习中最经典的数据集之一包含三种鸢尾花Setosa、Versicolour、Virginica的各50个样本每个样本有4个特征花萼长度、花萼宽度、花瓣长度、花瓣宽度单位均为厘米。数据集加载与探索from sklearn.datasets import load_iris import pandas as pd # 加载数据集 iris load_iris() # 转换为DataFrame方便查看 iris_df pd.DataFrame(iris.data, columnsiris.feature_names) iris_df[target] iris.target iris_df[species] iris_df[target].apply(lambda x: iris.target_names[x]) print(iris_df.head()) print(\n数据集描述:) print(iris_df.describe())3.2 数据可视化可视化是理解数据的重要步骤可以帮助我们发现数据分布和潜在模式。import seaborn as sns import matplotlib.pyplot as plt # 特征间关系矩阵图 sns.pairplot(iris_df, huespecies, palettehusl) plt.suptitle(鸢尾花特征关系矩阵, y1.02) plt.show() # 箱线图查看特征分布 plt.figure(figsize(12, 6)) sns.boxplot(datairis_df.drop([target], axis1), paletteSet2) plt.title(特征值分布箱线图) plt.xticks(rotation45) plt.show()3.3 完整建模流程一个完整的机器学习项目通常包含以下步骤数据加载与探索数据预处理数据集划分模型训练模型评估模型应用完整代码示例from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import classification_report, confusion_matrix import seaborn as sns # 1. 数据准备 X iris.data y iris.target # 2. 数据划分 X_train, X_test, y_train, y_test train_test_split( X, y, test_size0.2, random_state42, stratifyy) # 3. 特征工程 - 标准化 scaler StandardScaler() X_train_scaled scaler.fit_transform(X_train) X_test_scaled scaler.transform(X_test) # 注意使用训练集的参数转换测试集 # 4. 模型训练 knn KNeighborsClassifier(n_neighbors3) knn.fit(X_train_scaled, y_train) # 5. 模型评估 y_pred knn.predict(X_test_scaled) print(分类报告:) print(classification_report(y_test, y_pred, target_namesiris.target_names)) # 混淆矩阵可视化 cm confusion_matrix(y_test, y_pred) plt.figure(figsize(8, 6)) sns.heatmap(cm, annotTrue, fmtd, cmapBlues, xticklabelsiris.target_names, yticklabelsiris.target_names) plt.xlabel(预测标签) plt.ylabel(真实标签) plt.title(混淆矩阵) plt.show()3.4 模型优化与超参数调优KNN算法中最重要的超参数是K值的选择。我们可以使用网格搜索结合交叉验证来寻找最优K值。from sklearn.model_selection import GridSearchCV # 定义参数网格 param_grid {n_neighbors: range(1, 15)} # 创建GridSearchCV对象 grid_search GridSearchCV(KNeighborsClassifier(), param_grid, cv5, # 5折交叉验证 return_train_scoreTrue) # 执行网格搜索 grid_search.fit(X_train_scaled, y_train) # 输出最优参数 print(f最佳K值: {grid_search.best_params_[n_neighbors]}) print(f最佳交叉验证分数: {grid_search.best_score_:.3f}) # 可视化不同K值的表现 results pd.DataFrame(grid_search.cv_results_) plt.figure(figsize(10, 6)) plt.plot(param_grid[n_neighbors], results[mean_train_score], label训练分数) plt.plot(param_grid[n_neighbors], results[mean_test_score], label验证分数) plt.xlabel(K值) plt.ylabel(准确率) plt.title(K值选择与模型表现) plt.legend() plt.grid() plt.show()超参数调优经验K值太小会导致模型过拟合太大则会导致欠拟合。通常从K3或5开始尝试通过交叉验证确定最佳值。在鸢尾花数据集中K3或5通常表现最佳。4. 手写数字识别进阶实战4.1 数据集理解与准备MNIST手写数字数据集包含60,000个训练样本和10,000个测试样本每个样本是28x28像素的灰度图像像素值范围0-255。数据加载与探索import pandas as pd import matplotlib.pyplot as plt from collections import Counter # 加载数据 digits pd.read_csv(data_numberCV/手写数字识别.csv) # 查看数据分布 print(f样本总数: {len(digits)}) print(数字分布情况:) print(Counter(digits.iloc[:, 0])) # 可视化部分样本 plt.figure(figsize(10, 8)) for i in range(25): plt.subplot(5, 5, i1) plt.imshow(digits.iloc[i, 1:].values.reshape(28, 28), cmapgray) plt.title(fLabel: {digits.iloc[i, 0]}) plt.axis(off) plt.tight_layout() plt.show()4.2 数据预处理关键步骤手写数字识别需要特殊的数据预处理归一化将像素值从0-255缩放到0-1之间有助于模型收敛数据集划分保持各类别比例一致stratify维度处理将28x28图像展平为784维向量from sklearn.model_selection import train_test_split # 分离特征和标签 X digits.iloc[:, 1:].values y digits.iloc[:, 0].values # 归一化 X X / 255.0 # 数据集划分 X_train, X_test, y_train, y_test train_test_split( X, y, test_size0.2, random_state42, stratifyy) print(f训练集大小: {X_train.shape}) print(f测试集大小: {X_test.shape})4.3 模型训练与评估from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import accuracy_score, confusion_matrix import seaborn as sns import joblib # 初始化KNN模型 knn KNeighborsClassifier(n_neighbors3) # 训练模型 knn.fit(X_train, y_train) # 评估模型 y_pred knn.predict(X_test) accuracy accuracy_score(y_test, y_pred) print(f测试集准确率: {accuracy:.4f}) # 混淆矩阵 cm confusion_matrix(y_test, y_pred) plt.figure(figsize(10, 8)) sns.heatmap(cm, annotTrue, fmtd, cmapBlues) plt.xlabel(预测标签) plt.ylabel(真实标签) plt.title(混淆矩阵) plt.show() # 保存模型 joblib.dump(knn, digit_recognizer.pkl) print(模型保存成功!)4.4 模型应用与预测训练好的模型可以用于识别新的手写数字图像import cv2 import numpy as np def predict_digit(image_path, model_pathdigit_recognizer.pkl): # 加载模型 model joblib.load(model_path) # 读取并预处理图像 img cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) img cv2.resize(img, (28, 28)) # 确保尺寸正确 img 255 - img # 反色处理MNIST是白底黑字 img img / 255.0 # 归一化 # 展平图像并预测 img_flat img.reshape(1, -1) prediction model.predict(img_flat) probabilities model.predict_proba(img_flat) # 可视化 plt.imshow(img, cmapgray) plt.title(f预测数字: {prediction[0]}) plt.axis(off) plt.show() print(f预测数字: {prediction[0]}) print(各类别概率:) for i, prob in enumerate(probabilities[0]): print(f数字 {i}: {prob:.4f}) return prediction[0] # 使用示例 predict_digit(data_numberCV/demo.png)4.5 性能优化技巧PCA降维784维可能过高可以尝试PCA降低维度K值调优使用网格搜索寻找最佳K值距离度量尝试不同的距离度量如曼哈顿距离数据增强对训练数据进行旋转、平移等变换增加样本多样性from sklearn.decomposition import PCA # PCA降维 pca PCA(n_components0.95) # 保留95%的方差 X_train_pca pca.fit_transform(X_train) X_test_pca pca.transform(X_test) print(f原始维度: {X_train.shape[1]}) print(f降维后维度: {X_train_pca.shape[1]}) # 使用降维后的数据训练模型 knn_pca KNeighborsClassifier(n_neighbors3) knn_pca.fit(X_train_pca, y_train) accuracy_pca knn_pca.score(X_test_pca, y_test) print(f降维后测试准确率: {accuracy_pca:.4f})5. 项目总结与经验分享5.1 KNN算法的优缺点总结优点原理简单易于理解和实现无需训练过程适合增量学习对数据分布没有假设在多分类问题上表现良好缺点计算复杂度高需要存储全部训练数据对高维数据效果不佳维度灾难对不平衡数据敏感需要合适的距离度量和K值选择5.2 实战中的经验教训数据预处理至关重要忘记标准化会导致距离计算被大尺度特征主导图像数据必须进行适当的归一化和维度处理K值选择的技巧通常选择奇数K值以避免平票可以使用交叉验证确定最佳K值对于大型数据集K值可以适当增大距离度量的选择欧氏距离是默认选择对于高维数据余弦相似度可能更合适可以尝试不同的距离度量并比较效果计算效率优化对于大型数据集考虑使用KD树或球树数据结构可以采样部分数据作为代表点使用PCA等降维技术减少特征维度5.3 项目扩展思路尝试其他分类算法与SVM、随机森林等算法比较效果尝试深度学习模型如CNN在手写数字识别上的表现开发完整应用构建Web应用允许用户上传手写数字图片进行识别开发移动应用实时识别摄像头拍摄的数字探索更复杂的数据集尝试Fashion-MNIST等更复杂的数据集收集自己的手写数字数据集进行训练算法优化方向实现加权KNN让更近的邻居有更大投票权尝试自适应确定K值的算法在实际项目中KNN虽然简单但往往能提供不错的基线性能。理解其原理和局限性能帮助我们在合适的场景中有效应用这一算法。