K-Means案例实际讲解,适合大学生突击期末

K-Means案例实际讲解,适合大学生突击期末 太棒了既然你兴致这么高我们就把 sklearn 的拐杖扔掉用刚才学过的 NumPy 核心知识广播、向量化、布尔索引手搓一个简易版的 K-Means 算法你会发现剥去框架的外衣K-Means 的核心逻辑其实只有 3 步而且用 NumPy 写出来极其优雅。K-Means 的核心思想一句话“找几个中心点把离谁近的点归给谁然后重新算中心点直到中心点不再移动。”完整代码实现带详细注释import numpy as npclass SimpleKMeans:definit(self, n_clusters3, max_iters100):self.n_clusters n_clustersself.max_iters max_itersdef fit(self, X): # X 的形状是 (样本数, 特征数)比如鸢尾花是 (150, 2) n_samples, n_features X.shape # 【初始化】随机挑选 n_clusters 个点作为初始中心 # 这里用整数列表索引花式索引随机抽取 random_indices np.random.choice(n_samples, self.n_clusters, replaceFalse) self.centroids X[random_indices, :] # 【迭代开始】 for _ in range(self.max_iters): # 第 1 步计算距离并分配簇 # 利用广播计算每个点到每个中心的距离 # X 形状: (150, 2), centroids 形状: (3, 2) # 我们想让 X 减去每一个中心得到 (150, 3, 2) 的三维数组 # 技巧把 centroids 变成 (1, 3, 2)X 变成 (150, 1, 2) distances np.sqrt(((X[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]) ** 2).sum(axis2)) # 找出每个点距离最近的那个中心的索引形状变为 150 labels np.argmin(distances, axis1) # 第 2 步重新计算中心点 new_centroids np.zeros((self.n_clusters, n_features)) for k in range(self.n_clusters): # 【布尔索引】把属于第 k 类的点全部挑出来求平均 cluster_points X[labels k] # 如果某个簇没有分到点保持原中心不变 if len(cluster_points) 0: new_centroids[k] cluster_points.mean(axis0) # 第 3 步判断是否收敛 # 如果新旧中心点的距离小于极小值说明不再移动提前结束 if np.allclose(self.centroids, new_centroids): print(f算法在第 {_} 次迭代时收敛) break self.centroids new_centroids return labels, self.centroids代码里的“高光时刻”解析重点看距离计算的“降维打击”distances np.sqrt(((X[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]) ** 2).sum(axis2))这一行是整个算法最核心的向量化操作X[:, np.newaxis, :] 把形状从 (150, 2) 变成了 (150, 1, 2)。centroids[np.newaxis, :, :] 把形状从 (3, 2) 变成了 (1, 3, 2)。两者相减时NumPy 触发终极广播直接生成一个 (150, 3, 2) 的立体矩阵。.sum(axis2) 沿着特征维度求和得到 (150, 3) 的距离矩阵。np.argmin(axis1) 直接找出每一行每个样本距离最小的中心点编号。没有写一行 for 循环150 个点的距离瞬间算完更新中心点的“布尔索引”cluster_points X[labels k]new_centroids[k] cluster_points.mean(axis0)labels k 生成一个布尔掩码Mask。直接通过掩码把属于第 k 类的样本捞出来。mean(axis0) 沿着行样本的方向求平均瞬间算出新的中心点。怎么跑起来造一点假数据测试一下np.random.seed(42)生成 3 个簇的假数据cluster_1 np.random.randn(50, 2) [2, 2]cluster_2 np.random.randn(50, 2) [-2, -2]cluster_3 np.random.randn(50, 2) [2, -2]fake_data np.vstack([cluster_1, cluster_2, cluster_3])运行我们的手搓算法kmeans SimpleKMeans(n_clusters3)labels, centroids kmeans.fit(fake_data)print(“找到的中心点n”, centroids)学习小结到这里你不仅学会了K-Means 的底层原理分配 - 更新 - 收敛。NumPy 的进阶玩法np.newaxis 制造维度、三维数组广播、布尔索引。你现在已经具备了阅读和理解大部分传统机器学习算法源码的能力了感觉怎么样是不是有种打通任督二脉的爽快感接下来你是想看看怎么把这个结果画成漂亮的散点图还是想了解一下K-Means 算法的致命缺陷比如怎么确定 K 值