基于MATLAB的分类和回归树(CART)通用算法框架
项目介绍
本项目提供了一个在MATLAB环境下实现的分类和回归树(CART)通用算法框架。该框架能够根据输入任务的需求,自动构建决策树模型,支持执行非线性分类任务和连续变量回归任务。其核心逻辑依赖于将特征空间递归地划分为二叉树结构,通过在每个分点精确寻找最优特征与阈值,实现对复杂数据的建模。
项目不仅包含决策树的生长逻辑,更集成了复杂的后剪枝算法与模型评估体系,旨在为科研和工程应用提供一个透明、可修改且高性能的决策树实现方案。
功能特性
- 任务双模支持:内置分类(基于基尼系数 Gini Index)与回归(基于最小二乘偏差/均方误差)两套计算逻辑。
- 代价复杂度剪枝 (CCP):集成了完整的剪枝流程,通过计算剪枝系数生成子树序列。
- 交叉验证优化:利用K折交叉验证(K-Fold Cross-Validation)自动在剪枝序列中筛选最优模型,有效防止过拟合。
- 多维度停机准则:支持设置最大深度、最小分裂样本数、叶节点最小样本数以及最小增益阈值。
- 特征重要性度量:基于节点分裂时的增益贡献度,自动计算并归一化所有输入特征的重要性评分。
- 结构化可视化:提供直观的树状分支逻辑图和特征重要性柱状图。
- 纯MATLAB实现:不依赖外部工具箱,包含了自定义的交叉验证索引分配函数,具有良好的兼容性。
使用方法
- 环境配置:打开MATLAB,将项目所在文件夹添加至工作路径。
- 数据准备:在主程序入口处准备数据。目前程序默认生成一个包含4维特征的模拟数据集,用户可将其替换为自定义的矩阵或表格数据。
- 参数调节:根据任务类型(分类或回归)修改配置结构体中的标志位,并调整深度、样本量限制等超参数。
- 运行程序:执行主脚本。程序会自动完成训练集/测试集划分、完全树生长、剪枝序列生成、交叉验证选择、最终预测及绘图。
- 结果查看:控制台将输出分类准确率或回归均方根误差(RMSE),并弹出两个图形窗口展示特征权重与树结构。
系统要求
- MATLAB R2016b 或更高版本。
- 基本的标量/矩阵运算支持。
核心实现逻辑说明
主程序通过以下流程完成决策树的生命周期管理:
- 数据预处理阶段:生成四路特征的合成数据,利用正弦函数和二次项模拟复杂的非线性回归信号,并通过阈值判定生成对应的分类标签。
- 递归生长阶段:从根节点开始,程序在每一个节点尝试所有可能的特征及对应的所有唯一取值作为切分点。通过比较切分前后的不纯度(Gini或误差平方和)减少量,锁定最优分裂策略。
- 剪枝序列构造阶段:构建一颗“完全生长”的树后,通过递归遍历寻找剪枝系数 $g(t)$ 最小的内部节点,将其强行转化为叶节点,循环往复直至剩余根节点,从而得到一系列由繁到简的子树。
- 模型优选阶段:在交叉验证环节,将原始训练数据划分为多个折叠。对每一折数据分别训练并应用上述剪枝系数,计算子树在验证集上的误差,最终选择平均性能最佳的那个alpha等级。
- 推断与评估阶段:采用批量预测机制,逐个样本递归向下搜索直至叶节点获取预测值。特征重要性通过累加每个特征在所有非叶节点产生的加权增益实现。
关键算法与函数解析
- 树构建逻辑函数:作为递归核心,它负责根据节点数据状态决定是否停止生长(如达到最大深度或样本量不足),并创建包含节点值、不纯度、特征索引和阈值的结构体模型。
- 最优切分点搜索:该算法执行二分搜索。对于每个连续特征,它会遍历该特征下所有的唯一值,模拟切分效果,是计算复杂度最高但最核心的部分。
- 不纯度计算引擎:
- 分类模式:统计各类别概率的平方和,计算基尼指数,反映样本类别的随机性。
- 回归模式:计算样本值相对于其均值的平方误差和,衡量局部区域的离散程度。
- 代价复杂度计算:该部分算法通过计算 $R(t) - R(T_t)$ 与叶子节点数量变化的比值,量化剪枝带来的单位结构简化代价。
- 路径剪枝应用:根据给定的节点路径,精确定位并切断子树,同时保留节点作为叶子时的预测统计量。
- 特征重要性评估算法:算法从根节点开始向下求和。每个特征获得的权重等于该特征分裂时带来的增益乘以受影响的样本总数。
- 可视化辅助系统:利用MATLAB的绘图函数对稀疏矩阵定义的树形结构进行坐标布局,并动态标注切分逻辑(X <= threshold)与叶子节点值。