基于MATLAB的高效线性判别分析(LDA)分类算法
项目简介
本项目是一个完全基于MATLAB编写的线性判别分析(Linear Discriminant Analysis, LDA)算法实现。该程序基于Fisher准则,旨在通过寻找最佳投影方向,最大化类间散度(Between-class scatter)并最小化类内散度(Within-class scatter),从而实现高维数据的有效降维与分类。
项目不仅仅是一个算法库,还包含了一个完整的端到端演示流程:从高维合成数据的生成、数据预处理、LDA模型训练、低维空间投影到最终的最小距离分类器预测及结果可视化。该实现经过逻辑优化,特别处理了协方差矩阵可能存在的奇异性问题,适用于机器学习教学、模式识别研究以及数据降维的特征提取阶段。
核心功能特性
- 自动化高维数据模拟:内置数据生成模块,能够自动生成符合多变量高斯分布的合成数据集(默认5维特征,3个类别),用于验证算法有效性。
- 鲁棒的Fisher LDA实现:
* 自动计算各类均值向量及全局均值。
* 构建类内散布矩阵($S_w$)和类间散布矩阵($S_b$)。
*
正则化处理:在计算过程中加入了微小的正则化项,防止因矩阵奇异(Singular Matrix)导致的数值不稳定或不可逆问题。
*
广义特征值求解:直接求解 $S_b w = lambda S_w w$,提取最具判别力的特征向量。
- 最小距离分类器:在降维后的LDA子空间内,通过计算测试样本与各类别中心的欧氏距离进行分类,计算效率高。
- 全流程评估与可视化:
* 自动划分训练集与测试集(默认7:3比例)。
* 计算并输出分类准确率。
* 双视图可视化对比:展示原始高维空间(前两维)与LDA投影空间的样本分布差异,并直观标记分类错误的样本。
系统要求
- MATLAB R2016a 或更高版本(代码中使用了
mvnrnd 等统计工具箱函数)。 - Statistics and Machine Learning Toolbox。
快速开始
- 确保MATLAB环境已准备就绪。
- 直接运行
main.m 脚本(或在该文件所在的目录下执行 main 命令)。 - 程序将自动执行数据生成、模型训练与评估。
- 控制台将输出训练耗时、数据集信息及最终的分类准确率。
- 系统将弹出一个包含两个子图的结果窗口,展示降维前后的数据分布对比。
算法原理与代码实现细节
本项目的主要逻辑封装在 main 函数及其调用的三个辅助子函数中,具体实现逻辑如下:
1. 数据准备与预处理
程序首先通过设定随机种子(
rng)保证实验结果的可复现性。随后生成3类服从不同均值和协方差的高斯分布数据,总样本量为400个,特征维度为5维。数据生成后被合并并随机打乱,按70%训练集、30%测试集的比例进行划分。
2. 模型训练 (Fisher LDA 核心)
训练过程由核心算法函数完成,逻辑如下:
- 统计量计算:遍历训练数据,分别计算全局均值和每个类别的局部均值。
- 散布矩阵构建:
*
类间散布矩阵 ($S_b$):计算各类均值与全局均值之差的加权外积。
*
类内散布矩阵 ($S_w$):计算各样本与其所属类均值偏差的外积之和。
- 正则化策略:为了保证 $S_w$ 的可逆性及数值稳定性,代码显式向 $S_w$ 添加了
1e-4 * eye(n_features) 的正则化项。这是一个关键的工程优化,避免了在特征数接近样本数或特征共线时程序崩溃。 - 特征提取:利用MATLAB高效的
eig(Sb, Sw) 函数求解广义特征值问题。特征值被降序排列,根据预设的目标维度(类别数 - 1),选取前 $k$ 个最大的特征值对应的特征向量组成投影矩阵 $W$。 - 空间转换:计算训练集在低维空间中的类中心(Centroids),作为后续分类的基准。
3. 投影与推理预测
预测模块接收测试数据和训练好的投影矩阵:
- 降维投影:通过矩阵乘法 $X_{test} times W$ 将测试数据映射到LDA子空间。
- 最小距离分类:遍历每个测试样本,计算其在低维空间中与各个训练类中心的欧氏距离,将样本归类为距离最近的类别。
4. 性能评估与可视化
- 准确率计算:对比预测标签与真实标签,计算分类精度并打印。
- 可视化展示:
*
左图(原始空间):选取原始数据的前两个特征维度进行散点图绘制,展示数据未处理前的分布状态。
*
右图(LDA空间):绘制经LDA降维后的数据分布。由于LDA最大化了类间距离并压缩了类内距离,通常能观察到更清晰的类别分离边界。
*
错误标记:图表中会自动检测分类错误的样本,并使用红色的 "x" 符号进行高亮标记,便于分析误判情况。