别再死记公式了!用Python手撸一个LDA分类器(附完整代码与数据集)

别再死记公式了!用Python手撸一个LDA分类器(附完整代码与数据集) 从数学推导到代码实现用Python手写LDA分类器的实战指南很多机器学习初学者在学习线性判别分析LDA时往往能理解其数学原理但当真正要用代码实现时却无从下手。本文将带你用NumPy从零开始构建一个LDA分类器并在鸢尾花数据集上进行验证。不同于简单地调用scikit-learn我们将深入算法核心一步步实现类内散度矩阵、类间散度矩阵的计算以及特征值分解等关键步骤。1. LDA算法核心原理回顾LDA的核心思想是通过线性变换将高维数据投影到低维空间使得同类样本尽可能接近不同类样本尽可能远离。这种投影本质上是在寻找一个最优的超平面。对于二分类问题LDA需要计算两个关键矩阵类内散度矩阵(Sw)衡量同一类别样本的分散程度类间散度矩阵(Sb)衡量不同类别样本中心的距离最终的投影方向w通过求解广义特征值问题得到Sw⁻¹ * Sb * w λ * w注意在实际计算中我们通常选择最大特征值对应的特征向量作为投影方向。2. 数据准备与预处理我们将使用经典的鸢尾花数据集作为示例。这个数据集包含150个样本每个样本有4个特征花萼长度、花萼宽度、花瓣长度、花瓣宽度和1个类别标签Setosa、Versicolor、Virginica。from sklearn.datasets import load_iris import numpy as np # 加载数据集 iris load_iris() X iris.data y iris.target # 为了简化问题我们只使用前两个类别和两个特征 X X[y ! 2, :2] y y[y ! 2] # 数据标准化 X (X - np.mean(X, axis0)) / np.std(X, axis0)3. 核心算法实现3.1 计算类内散度矩阵Sw类内散度矩阵是各个类别协方差矩阵的和。对于二分类问题def compute_within_class_scatter(X, y): # 获取类别数量 classes np.unique(y) n_features X.shape[1] # 初始化Sw矩阵 Sw np.zeros((n_features, n_features)) for c in classes: # 获取当前类别的样本 X_c X[y c] # 计算类别均值 mean_c np.mean(X_c, axis0) # 计算协方差矩阵 cov_c (X_c - mean_c).T (X_c - mean_c) # 累加到Sw Sw cov_c return Sw3.2 计算类间散度矩阵Sb类间散度矩阵衡量的是不同类别均值之间的差异def compute_between_class_scatter(X, y): # 获取全局均值 mean_total np.mean(X, axis0) # 获取类别信息 classes np.unique(y) n_features X.shape[1] # 初始化Sb矩阵 Sb np.zeros((n_features, n_features)) for c in classes: # 获取当前类别的样本数和均值 X_c X[y c] n_c X_c.shape[0] mean_c np.mean(X_c, axis0) # 计算类间贡献 diff mean_c - mean_total Sb n_c * np.outer(diff, diff) return Sb3.3 求解投影方向有了Sw和Sb后我们需要求解广义特征值问题def compute_lda_projection(X, y): # 计算Sw和Sb Sw compute_within_class_scatter(X, y) Sb compute_between_class_scatter(X, y) # 计算Sw的逆 Sw_inv np.linalg.inv(Sw) # 计算Sw⁻¹ * Sb matrix Sw_inv Sb # 计算特征值和特征向量 eigenvalues, eigenvectors np.linalg.eig(matrix) # 选择最大特征值对应的特征向量 max_idx np.argmax(eigenvalues) w eigenvectors[:, max_idx] return w4. 结果验证与可视化让我们将自实现的LDA与scikit-learn的LDA结果进行对比from sklearn.discriminant_analysis import LinearDiscriminantAnalysis import matplotlib.pyplot as plt # 自实现LDA w_manual compute_lda_projection(X, y) # scikit-learn LDA lda LinearDiscriminantAnalysis() lda.fit(X, y) w_sklearn lda.scalings_[:, 0] # 归一化比较 w_manual w_manual / np.linalg.norm(w_manual) w_sklearn w_sklearn / np.linalg.norm(w_sklearn) print(手动实现投影方向:, w_manual) print(Scikit-learn投影方向:, w_sklearn) # 可视化投影结果 def plot_projection(X, y, w, title): # 计算投影 projection X w plt.figure() for c in np.unique(y): plt.hist(projection[y c], alpha0.5, labelfClass {c}) plt.title(title) plt.legend() plot_projection(X, y, w_manual, Manual LDA Projection) plot_projection(X, y, w_sklearn, Scikit-learn LDA Projection) plt.show()5. 实际应用中的注意事项在实现LDA时有几个关键点需要特别注意矩阵求逆的稳定性当Sw接近奇异矩阵时直接求逆可能导致数值不稳定解决方案是使用伪逆或加入小的正则项Sw_reg Sw 1e-6 * np.eye(Sw.shape[0])多分类问题的扩展对于多分类问题LDA可以生成最多C-1个判别方向C是类别数需要选择前k个最大特征值对应的特征向量特征选择的影响LDA假设数据服从高斯分布且各类协方差矩阵相同如果这些假设不成立分类性能可能会下降与PCA的区别PCA是无监督的寻找最大方差方向LDA是有监督的寻找最佳分类方向6. 性能优化技巧为了提高LDA实现的效率和稳定性可以考虑以下优化批量矩阵运算利用NumPy的广播机制避免循环内存优化对于高维数据使用SVD代替直接求逆并行计算对于多分类问题可以并行计算各类统计量# 优化后的Sw计算 def compute_within_class_scatter_optimized(X, y): classes np.unique(y) means [np.mean(X[y c], axis0) for c in classes] # 使用广播一次性计算所有类别的协方差 centered [X[y c] - means[i] for i, c in enumerate(classes)] covs [c.T c for c in centered] return np.sum(covs, axis0)7. 扩展应用LDA在文本分类中的实践虽然我们以鸢尾花数据集为例但LDA在文本分类中也有广泛应用。例如在新闻分类中将文本转换为TF-IDF特征向量计算类内和类间散度矩阵降维后使用简单分类器如朴素贝叶斯这种方法的优势在于有效降低特征维度保留最具判别性的特征计算效率高适合大规模文本数据在实现文本分类的LDA时需要注意特征矩阵通常是稀疏的可以使用scipy.sparse中的专用函数来提高计算效率。