基于有导师学习神经网络的花朵识别系统
项目简介
本项目是一个基于MATLAB开发环境的构建的花朵自动分类与识别系统。利用MATLAB强大的数值计算能力与神经网络工具箱,项目演示了有导师学习(Supervised Learning)的全过程。系统通过模拟生成具有特定统计规律的花朵特征数据,训练一个BP(Back Propagation)神经网络模型,从而建立花朵形态特征与种类之间的映射关系。该代码旨在为初学者提供一个标准、规范且逻辑清晰的深度学习入门案例,涵盖了数据生成、预处理、模型构建、训练、评估及可视化的完整工作流。
功能特性
- 独立运行的数据模拟:无需外部数据文件,代码内部利用高斯分布算法自动生成类似鸢尾花(Iris)数据集的特征矩阵(4维特征,3种类别),保证了程序的便携性和可重复性。
- 完整的神经网络工作流:实现了从数据清洗、数据集划分(训练/验证/测试)、网络构建到模型训练的全套流程。
- 多维度的性能评估:通过计算分类准确率、绘制训练性能曲线(MSE)、混淆矩阵(Confusion Matrix)以及ROC曲线,全方位评估模型性能。
- 直观的结果可视化:引入主成分分析(PCA)算法对高维数据进行降维,并使用散点图直观展示测试集的真实分类分布与预测分类分布的对比。
- 由浅入深的教学演示:代码包含单独的样本预测模块,演示了如何将训练好的模型应用于未知数据的推理。
系统要求
- 软件版本:MATLAB R2014b 或更高版本(推荐使用较新版本以获得更好的图形支持)。
- 工具箱依赖:
* Deep Learning Toolbox (原 Neural Network Toolbox):用于构建和训练神经网络。
* Statistics and Machine Learning Toolbox:用于PCA降维分析和数据生成。
使用方法
- 启动MATLAB软件。
- 将当前工作目录切换至源码所在文件夹。
- 在命令行窗口输入主函数名称
main 并回车,或直接点击编辑器中的“运行”按钮。 - 程序将自动执行数据生成、模型训练,并弹出多个图形窗口展示训练过程和评估结果。
- 命令行窗口将输出训练进度、分类准确率以及模拟样本的预测结果。
代码功能与实现逻辑分析
本项目的主程序 main.m 通过模块化的方式实现了以下具体逻辑:
1. 数据生成与导入
程序首先初始化环境并设置随机种子(
rng(42))以确保结果可复现。
- 特征模拟:使用
bsxfun 函数结合 randn(标准正态分布),基于预设的均值(mu)和标准差(sigma)生成三类花朵的特征数据。每类生成50个样本,每个样本包含4个特征维度(模拟花萼长度、宽度等)。 - 标签构建:对应生成的数据,构建目标类别矩阵。采用了 One-Hot 编码(独热编码)格式,即类别1表示为
[1;0;0],类别2表示为 [0;1;0],以此类推。 - 数据合并:将三类数据合并为特征矩阵
X 和标签矩阵 Y。
2. 数据集划分与预处理
- 随机打乱:使用
randperm 函数生成随机索引,将原始数据和标签同步打乱,消除数据顺序对训练的影响。 - 集划分:通过
dividerand 函数将数据集按 70%训练集、15%验证集、15%测试集 的比例进行划分。虽然神经网络工具箱内部会自动调用这些索引,但代码中显式定义了比例以便控制。
3. 神经网络构建
- 模型选择:使用
patternnet 函数创建一个专门用于模式识别的前馈神经网络。 - 网络结构:设置隐藏层神经元数量为 10 个,输入层维度由数据自动决定(4维),输出层维度由标签决定(3维)。
- 参数配置:
* 训练函数默认使用量化共轭梯度法(
trainscg)。
* 最大迭代次数(Epochs)设为 1000。
* 目标误差(Goal)设为 1e-5。
* 开启训练窗口显示,便于实时观察梯度下降过程。
4. 网络训练
调用
train 函数,传入网络对象
net、打乱后的特征数据
X_shuffled 和目标标签
Y_shuffled。神经网络根据设定的划分比例自动拆分数据进行反向传播训练,更新权重和阈值。
5. 测试与预测
- 全集仿真:使用训练好的网络对所有数据进行预测。
- 测试集提取:利用训练记录
tr 中的 testInd 索引,专门提取测试集部分的真实标签和预测结果。 - 结果解码:使用
max 函数将 One-Hot 编码的概率输出转换为具体的类别索引(1, 2, 3)。 - 准确率计算:对比测试集的预测类别与真实类别,计算并输出分类准确率百分比。
6. 结果可视化
代码生成了四组关键图表:
- 性能曲线图:调用
plotperform,展示均方误差(MSE)随训练迭代次数下降的趋势,用于判断模型是否收敛或过拟合。 - 混淆矩阵图:调用
plotconfusion,可视化各类别的分类正确数和误判数,直观展示模型在各类上的表现。 - ROC曲线图:调用
plotroc,展示每个类别的接收者操作特征曲线,评估分类器的灵敏度和特异度。 - PCA降维可视化图:
* 利用
pca 函数将4维特征数据降维至2维。
* 使用
gscatter 函数分别绘制测试集的“真实类别分布”和“预测类别分布”对比图。这使得高维数据的分类效果可以在二维平面上直观展现。
7. 单个样本模拟测试
为了演示实际应用场景,代码最后构造了一个特定的未知样本向量(特征接近类别1)。调用训练好的网络进行预测,并解析输出的概率向量,打印出预测的类别归属及置信度。
关键算法与函数说明
- patternnet: MATLAB用于创建模式识别网络的专用函数,通常采用 Tansig(双曲正切S型)传递函数作为隐藏层激活函数,Softmax 作为输出层传输函数。
- train: 神经网络通用的训练驱动函数,根据配置的
trainFcn(本项目中隐含为 trainscg)执行优化算法。 - pca (Principal Component Analysis): 主成分分析算法,用于提取数据中的主要特征分量,本项目中用于将4维花朵数据压缩到2维以便绘图。
- bsxfun: 用于对应元素进行二元运算的函数,本项目中高效地实现了基于高斯分布参数的数据批量生成。
- One-Hot Encoding: 一种分类标签的编码方式,能够让神经网络输出各个类别的概率分布。