网站首页 > 技术文章 正文
这是Python机器学习系列原创文章,我的第204篇原创文章。
一、引言
对于表格数据,一套完整的机器学习建模流程如下:
针对不同的数据集,有些步骤不适用即不需要做,其中橘红色框为必要步骤,由于数据质量较高,本文有些步骤跳过了,跳过的步骤将单独出文章总结!同时欢迎大家关注翻看我之前的一些相关文章。
GradientBoostingClassifier是一种基于梯度提升算法的分类器,它是scikit-learn库中的一个类。梯度提升是一种集成学习方法,通过组合多个弱学习器(通常是决策树,梯度提升决策树GBDT)来构建一个更强大的分类器。梯度提升模型的基本思想是利用梯度下降来最小化损失函数,以逐步优化模型的预测能力。在每一轮迭代中,模型会计算当前模型对样本的预测值与实际值之间的残差,然后使用一个新的弱学习器来拟合这个残差。通过迭代地拟合残差,每个弱学习器都会以一定的学习率加入到模型中,最终得到一个强大的集成模型。本文将实现基于心脏疾病数据集建立梯度提升模型对心脏疾病患者进行分类预测的完整过程。
二、实现过程
导入必要的库
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
1、准备数据
data = pd.read_csv(r'Dataset.csv')
df = pd.DataFrame(data)
df:
数据基本信息:
print(df.head())
print(df.info())
print(df.shape)
print(df.columns)
print(df.dtypes)
cat_cols = [col for col in df.columns if df[col].dtype == "object"] # 类别型变量名
num_cols = [col for col in df.columns if df[col].dtype != "object"] # 数值型变量名
2、提取特征变量和目标变量
target = 'target'
features = df.columns.drop(target)
print(data["target"].value_counts()) # 顺便查看一下样本是否平衡
3、数据集划分
# df = shuffle(df)
X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=0)
4、模型的构建与训练
# 模型的构建与训练
model = GradientBoostingClassifier()
model.fit(X_train, y_train)
参数详解:
from sklearn.ensemble import GradientBoostingClassifier
# 全部参数
GradientBoostingClassifier(loss='log_loss',
learning_rate=0.1,
n_estimators=100,
subsample=1.0,
criterion='friedman_mse',
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.0,
max_depth=3,
min_impurity_decrease=0.0,
init=None,
random_state=None,
max_features=None,
verbose=0,
max_leaf_nodes=None,
warm_start=False,
validation_fraction=0.1,
n_iter_no_change=None,
tol=0.0001,
ccp_alpha=0.0)
- loss:损失函数的类型。默认为deviance,表示使用对数似然损失函数进行分类。可以选择exponential,表示使用指数损失函数进行分类。
- learning_rate:学习率,控制每个弱学习器的贡献。较小的学习率会使模型收敛得更慢,但可能会获得更好的性能。默认为0.1。
- n_estimators:弱学习器(决策树)的数量。默认为100。
- subsample:用于训练每个弱学习器的样本子集的比例。默认为1.0,表示使用全部样本。可以设置小于1.0的值来降低方差,防止过拟合。
- criterion:决策树节点分裂的标准。默认为friedman_mse,表示使用Friedman均方误差作为分裂标准。可以选择mse,表示使用均方误差,或mae,表示使用平均绝对误差。
- max_depth:决策树的最大深度。默认为3。增加深度可以增加模型的复杂度,但也容易导致过拟合。
- min_samples_split:决策树节点分裂所需的最小样本数。默认为2。如果某个节点的样本数少于该值,则不会再进行分裂。
- min_samples_leaf:叶节点所需的最小样本数。默认为1。如果叶节点的样本数少于该值,则不会进行进一步的分裂。
- max_features:每个决策树节点考虑的特征数量。可以是整数、浮点数或字符串。默认为None,表示考虑所有特征。可以选择sqrt,表示考虑特征数量的平方根,或log2,表示考虑特征数量的对数。
- random_state:随机种子。可以用于重现实验结果。
- verbose:控制训练过程中的输出信息的详细程度。默认为0,表示不输出任何信息。较大的值会增加输出信息的数量。
5、模型的推理与评价
y_pred = model.predict(X_test)
y_scores = model.predict_proba(X_test)
acc = accuracy_score(y_test, y_pred) # 准确率acc
cm = confusion_matrix(y_test, y_pred) # 混淆矩阵
cr = classification_report(y_test, y_pred) # 分类报告
fpr, tpr, thresholds = roc_curve(y_test, y_scores[:, 1], pos_label=1) # 计算ROC曲线和AUC值,绘制ROC曲线
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()
cm:
cr:
ROC:
三、小结
本文利用scikit-learn(一个常用的机器学习库)实现了基于心脏疾病数据集建立梯度提升模型对心脏疾病患者进行分类预测的完整过程。
作者简介:
读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历不定期持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。需要数据源码的朋友关注gzh:数据杂坛,或点击原文链接,联系作者。
原文链接:
猜你喜欢
- 2024-12-20 CatBoost、LightGBM、XGBoost,这些算法你都了解吗?
- 2024-12-20 前四场全对!博主通过AI预测巴西将夺世界杯冠军
- 2024-12-20 算法金 | 决策树、随机森林、Bagging、Adaboost、GBDT、XGBoost
- 2024-12-20 数据中台:从0到1打造一个离线推荐系统
- 2024-12-20 GBDT——梯度提升树算法详解 梯度提升数
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)