CART决策树算法的MATLAB实现与可视化系统
项目简介
本项目是一个基于MATLAB环境开发的分类与回归树(CART)算法完整实现。该系统不依赖任何高级机器学习工具箱,而是通过底层代码完全手写了决策树的核心逻辑,包括树的生长、代价复杂度剪枝(Cost-Complexity Pruning, CCP)、交叉验证以及模型可视化。
该项目旨在提供一个透明、可解释且易于理解的机器学习算法实现,特别展示了如何通过剪枝策略平衡模型的复杂度与泛化能力,适用于算法研究、教学演示验证以及分类预测任务。
主要功能特性
- 全流程纯手写实现:不调用
fitctree 等现成函数,从零构建节点分裂、递归生长及剪枝逻辑。 - 自动化模拟数据生成:系统内置数据生成模块,创建一个包含3个类别、4个特征(含2个噪声特征)的复杂数据集,模拟线性不可分情况,无需外部数据文件即可直接运行。
- 基尼系数(Gini Impurity)分裂:采用基尼不纯度作为节点分裂的评估指标,支持连续数值特征的最佳分割点搜索。
- 代价复杂度剪枝(CCP):实现了完整的CCP算法,自动计算剪枝路径中的有效Alpha值序列,通过最小化代价复杂度函数来生成最优子树序列。
- K折交叉验证(K-Fold CV):内置5折交叉验证机制,通过在验证集上的表现自动筛选出最佳的正则化参数(Alpha),有效防止模型过拟合。
- 自定义可视化引擎:开发了专门的递归绘图功能,能够生成清晰的树状拓扑图,展示节点分裂条件、判定路径及最终分类结果。
系统要求
- MATLAB版本:R2016b及以上(推荐)。
- 工具箱依赖:仅需基础MATLAB环境及Statistics and Machine Learning Toolbox(用于基础统计函数如
histcounts, rng, randn, randperm),无需高级深度学习或特定预测工具箱。
详细算法实现与逻辑说明
本项目通过一个主入口函数串联起数据处理、算法建模与评估的全过程,具体实现逻辑如下:
1. 数据准备与预处理
程序首先固定随机种子以确保结果可复现。随后生成一个包含300个样本的模拟数据集,分为3个类别。特征设计包含两部分:前两个特征具有特定的分布偏移,模拟真实分类信息;后两个特征为纯随机噪声,用于测试算法的抗噪能力。数据按7:3的比例随机划分为训练集和测试集。
2. 决策树完全生长(Full Growth)
在第一阶段,算法构建一棵完全生长的树。
- 递归构建:从根节点开始,递归地寻找最佳分裂特征。
- 分裂标准:遍历所有特征及其所有可能的切分点(相邻值的中间点),计算加权基尼系数。选择使加权基尼系数最小(即纯度提升最大)的特征和阈值进行二叉分裂。
- 停止条件:当节点内样本属于同一类别、样本数低于设定的最小叶节点阈值(默认为3),或达到最大深度(默认为10)时,停止分裂并标记为叶节点。
3. 代价复杂度剪枝(CCP)
这是本项目的核心亮点。算法不是简单地预剪枝,而是先让树完全生长,再通过CCP算法进行修剪。
- 最弱连接搜索:定义节点的误差代价 $R(t)$ 和以该节点为根的子树误差代价 $R(T_t)$。计算每个内部节点的表面误差增益 $g(t)$,该值衡量了剪掉该分支后,每减少一个叶节点所带来的误差增加量。
- 剪枝路径生成:迭代地寻找全局最小的 $g(t)$ 对应的节点(最弱连接)进行剪枝,将其变为叶节点。重复此过程直到整棵树剪枝为根节点,从而得到一系列嵌套的子树序列和对应的Alpha(罚分参数)序列。
4. 交叉验证与模型选择
为了从上述子树序列中选出泛化能力最强的一个,系统执行5折交叉验证。
- 过程:将训练数据分为5份,轮流作为验证集。在其余数据上重新训练树并在对应Alpha水平下进行测试。
- 最佳Alpha选取:统计各Alpha在验证集上的平均准确率,选择准确率最高且模型尽可能简单的Alpha值。
- 最终模型:根据选定的最佳Alpha,从全量训练数据的剪枝序列中提取出最终的最优子树。
5. 模型评估与可视化
- 性能指标:使用保留的30%测试集对最优子树进行测试,计算分类准确率并输出混淆矩阵,直观展示各类别的误判情况。
- 拓扑图绘制:利用递归绘图算法,根据节点的层级深度动态计算坐标。图中以连线表示父子关系,以文本形式标注"Yes/No"分支逻辑,叶节点展示最终的预测类别颜色标识。图表标题自动标注模型的最终测试准确率及选用的Alpha参数。
核心函数库分析
系统内部封装了以下核心子函数以支撑算法运行:
build_tree (构建树)
算法的主要递归函数。它接收数据集和参数,计算当前节点的基尼系数。如果未满足停止条件,它会通过双重循环遍历所有特征和切分点,找到最优分裂方案,并递归调用自身生成左右子树。
get_pruning_path (获取剪枝路径)
实现了CCP策略的主循环。它不断调用 find_weakest_link 来识别并剪除树中当前最不重要的分支,记录每一次剪枝后的子树结构和对应的Alpha阈值。
find_weakest_link (寻找最弱连接)
通过递归遍历树的所有节点,计算每个非叶节点的 $g(t)$ 值。使用了全局变量辅助递归过程中的最小值搜索,确保能准确定位到全局 $g(t)$ 最小的节点ID。
cross_validate_alpha (Alpha交叉验证)
实现了K折验证逻辑。它不仅负责数据切分,还包含了一个关键逻辑:在验证过程中,将训练好的大树根据目标Alpha序列进行匹配剪枝,以评估特定复杂度参数下的模型表现。
draw_tree_structure (树结构绘制)
可视化的核心。采用深度优先遍历策略,根据当前节点的坐标动态计算子节点的偏移量(随深度增加偏移量减半),利用MATLAB的绘图句柄绘制节点连接线和逻辑判断文本,直观展现决策树的层级结构。
calculate_gini (基尼系数计算)
统计辅助函数,计算给定标签集合的概率分布平方和,返回 $1 - sum p_i^2$,作为衡量节点不纯度的标准。
使用方法
- 确保MATLAB环境已按照就绪。
- 将包含所有代码的主脚本文件保存(例如
main.m)至当前工作目录。 - 在MATLAB命令行窗口输入函数名或点击运行按钮。
- 程序将自动输出:
* 构建过程的状态提示。
* 最优Alpha值及对应的节点数量。
* 测试集准确率和混淆矩阵。
* 一个独立的图形窗口,显示最终决策树的结构图。