网站首页 > 技术文章 正文
Python 机器学习中,CART(Classification And Regression Trees)算法用于构建决策树,用于分类和回归任务。剪枝(Pruning)是一种避免决策树过拟合的技术,通过减少树的大小来提高模型的泛化能力。CART剪枝分为预剪枝和后剪枝两种主要方式。
参考文档:https://www.cjavapy.com/article/3264/
1、预剪枝(Pre-Pruning)
预剪枝涉及在决策树完全生成之前停止树的增长。可以通过设置一些停止条件来实现,
1)树达到预定的最大深度(max_depth)
2)节点中的样本数量少于预定阈值(min_samples_split
3)分割后的节点的信息增益小于某个阈值,
4)节点中样本的纯度(比如,用基尼指数或熵测量)已经足够高。
预剪枝简单易实现,但可能过于保守,有时会导致模型欠拟合。
2、后剪枝(Post-Pruning)
后剪枝,也称为剪枝,是在决策树完全生成之后进行的。它通过删除树的部分子树或节点来减少树的复杂度,选择那些能够提高交叉验证数据集准确率的剪枝。后剪枝策略包括成本复杂度剪枝(Cost Complexity Pruning)、错误率降低剪枝(Reduced Error Pruning)和最小错误剪枝(Minimum Error Pruning)。
成本复杂度剪枝(Cost Complexity Pruning)是通过最小化一个称为成本复杂度的函数来实现剪枝。这个函数是树的错误率和树的复杂度的加权和。
错误率降低剪枝(Reduced Error Pruning)是从叶节点开始,尝试移除每个节点,如果移除后对验证集的分类准确性没有影响或者有所提高,则进行剪枝。
最小错误剪枝(Minimum Error Pruning)是在每个节点上应用一个简单的启发式规则,如果剪枝不会导致错误率增加,则执行剪枝。
3、cart剪枝的作用
决策树通过递归地选择最佳属性将数据集分割,构建出一个树状的分类模型。但一个没有限制的决策树很容易过度拟合训练数据,导致模型在未知数据上的泛化能力下降。为了解决这个问题,决策树剪枝技术被提出来,以提高决策树模型的泛化能力。通过剪掉不必要的节点,减少模型对训练数据中噪声的拟合,从而提高模型在未见数据上的泛化能力。剪枝后的决策树模型更简洁,易于理解和解释,有利于提高模型的可解释性。简化后的模型在预测时计算量更小,预测速度更快。剪枝是提高决策树模型泛化能力和效率的重要技术之一,是决策树算法中不可或缺的一部分。在实际应用中,通过适当选择剪枝策略和参数,可以大幅提升模型的性能。
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 加载iris数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 训练一个决策树模型(未剪枝)
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
# 预测测试集
y_pred = clf.predict(X_test)
# 评估模型
accuracy_without_pruning = accuracy_score(y_test, y_pred)
# 训练一个决策树模型(使用代价复杂度剪枝)
clf_pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=0.01) # ccp_alpha是剪枝的复杂度参数
clf_pruned.fit(X_train, y_train)
# 预测测试集
y_pred_pruned = clf_pruned.predict(X_test)
# 评估模型
accuracy_with_pruning = accuracy_score(y_test, y_pred_pruned)
print(accuracy_without_pruning, accuracy_with_pruning)
4、cart剪枝的应用
CART(Classification and Regression Trees)算法用于构建决策树,既可以用于分类问题也可以用于回归问题。一棵完全生长的决策树往往会过于复杂,导致过拟合,即在训练数据上表现很好但在未见过的数据上表现不佳。为了解决这个问题,可以采用剪枝(pruning)技术来简化决策树,提高模型的泛化能力。CART 决策树的构建过程采用贪心算法,不断地划分数据集,直到满足停止条件。DecisionTreeClassifier 是 scikit-learn 中用于解决分类问题的决策树算法实现。常用参数如下,
参数 | 描述 |
criterion | 用于衡量分裂质量的函数。 支持的标准有 'gini'(基尼不纯度) 和 'entropy'(信息增益)。 |
splitter | 选择每个节点处分裂策略的策略。 支持的策略有 'best'(选择最佳分裂) 和 'random'(选择最佳随机分裂)。 |
max_depth | 树的最大深度。如果为 None, 则节点扩展直到所有叶子都是纯净的, 或直到所有叶子 包含小于 min_samples_split 样本的数量。 |
min_samples_split | 分裂内部节点所需的最小样本数。 |
min_samples_leaf | 一个叶节点所需的最小样本数。 |
max_features | 寻找最佳分裂时要考虑的特征数量。 可以是整数、浮点数、'auto'、'sqrt' 或 'log2'。 |
random_state | 控制估计器的随机性。 即使当 splitter 设置为 'best' 时, 每次分裂时特征也总是随机置换的。 |
max_leaf_nodes | 以最佳先行方式生长树,最大叶节点数。 如果为 None,则叶节点数量不受限制。 |
min_impurity_decrease | 如果此分裂导致不纯度的减少大于或等于此值, 则会分裂节点。 |
class_weight | 与类相关的权重,形式为 {class_label: weight}。 如果未给出,则所有类假定有权重一。 |
ccp_alpha | 用于最小化代价复杂度剪枝的复杂度参数。 将选择代价复杂度最大且小于 ccp_alpha 的子树。 |
使用代码,
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 训练决策树模型
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X_train, y_train)
# 成本复杂度剪枝参数
path = tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
# 对每个ccp_alpha训练一个决策树并评估其性能
trees = []
for ccp_alpha in ccp_alphas:
tree = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
tree.fit(X_train, y_train)
trees.append(tree)
# 选择最佳的ccp_alpha值(可根据测试集性能来选择)
# 这里简化了选择过程,实际应用中应该使用交叉验证等方法
# 可视化决策树
plt.figure(figsize=(20,10))
plot_tree(trees[-1], filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.draw()
plt.show()
参考文档:https://www.cjavapy.com/article/3264/
猜你喜欢
- 2024-11-12 电力系统领域,电力系统暂态稳定判别方法
- 2024-11-12 机器学习“司马家族”——树族 机器学习实战树回归
- 2024-11-12 大白话人工智能算法-第27节决策树系列之预剪枝和后减枝(6)
- 2024-11-12 机器学习之图解 GBDT 的构造和预测过程
- 2024-11-12 机器学习算法之随机森林算法通俗易懂版本
- 2024-11-12 决策树之 GBDT 算法 - 回归部分 gbdt和决策树
- 2024-11-12 大数据:如何用决策树解决分类问题
- 2024-11-12 几种特征选择方法的比较,孰好孰坏?
- 2024-11-12 决策树算法之随机森林 决策树和随机森林预测结果
- 2024-11-12 3分钟掌握机器学习中的决策树 机器学习和深度学习决策树
你 发表评论:
欢迎- 11-13第一次养猫的人养什么品种比较合适?
- 11-13大学新生活不适应?送你舒心指南! 大学新生的不适应主要有哪些方面
- 11-13第一次倒班可能会让人感到有些不适应,以下是一些建议
- 11-13货物大小不同装柜算法有哪些?怎么算?区别有哪些?
- 11-13五大基本算法 五大基本算法是什么
- 11-13高级程序员必备:分治算法分享 分冶算法
- 11-13最快速的寻路算法 Jump Point Search
- 11-13手机实时人工智能之「三维动作识别」:每帧只需9ms
- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)