Scikit-learn 中文教程

第二部分:Scikit-learn 核心基础
第 3 章 Scikit-learn 核心设计与 API 体系
第 4 章 数据集模块与数据划分
第三部分:数据预处理与特征工程
第 5 章 数据预处理核心模块(sklearn.preprocessing)
第 6 章 特征工程:提取、选择与构建
第四部分:模型评估与验证
第 7 章 模型评估指标(按任务类型划分)
第 8 章 模型验证与超参数调优
第五部分:Scikit-learn 核心算法模块
第 9 章 有监督学习:分类算法
第 10 章 有监督学习:回归算法
第 11 章 无监督学习:聚类与密度算法
第 12 章 半监督学习与其他常用算法
第八部分:性能优化与问题解决
第 18 章 Scikit-learn 性能优化
第 19 章 Scikit-learn 常见问题与解决方案

8.4 模型选择与融合

Scikit-learn 高级教程:模型选择与融合 - 交叉验证、投票法与堆叠法

Scikit-learn 中文教程

本教程章节深入讲解Scikit-learn中的模型选择与融合技术,涵盖基于交叉验证的性能横向比较、投票法基础应用以及堆叠法分层融合策略,适合新手学习提升机器学习项目性能。

推荐工具
PyCharm专业版开发必备

功能强大的Python IDE,提供智能代码补全、代码分析、调试和测试工具,提高Python开发效率。特别适合处理列表等数据结构的开发工作。

了解更多

模型选择与融合:提升机器学习性能的关键技术

引言

在机器学习项目中,单一模型可能无法在所有情况下都表现出色。通过模型选择,我们可以找到最适合数据的算法;而模型融合则能结合多个模型的优势,提高预测准确性和鲁棒性。本章节将引导您使用Scikit-learn进行模型选择与融合,包括基于交叉验证的模型对比、投票法基础以及堆叠法高级融合,内容设计为新学习者友好,逐步解释核心概念和实操代码。

基于交叉验证的模型对比(不同算法的性能横向比较)

交叉验证是一种稳健的模型评估方法,它通过多次划分数据来减少过拟合风险,帮助公平比较不同算法的性能。

为什么使用交叉验证?

  • 避免数据划分的偶然性:交叉验证多次训练和测试模型,提供更可靠的性能估计。
  • 适合小数据集:特别适用于数据量有限的情况。

如何使用Scikit-learn进行交叉验证模型对比?

Scikit-learn提供了cross_val_score函数来评估模型性能。以下示例比较三种常见分类算法:逻辑回归、决策树和支持向量机。

示例代码:

from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC

# 假设X和y是预处理好的特征和标签数据
X = # 特征数据
y = # 标签数据

# 初始化模型
log_reg = LogisticRegression()
dec_tree = DecisionTreeClassifier()
svc = SVC()

# 使用5折交叉验证评估每个模型
cv_scores_log = cross_val_score(log_reg, X, y, cv=5)
cv_scores_tree = cross_val_score(dec_tree, X, y, cv=5)
cv_scores_svc = cross_val_score(svc, X, y, cv=5)

# 输出平均交叉验证分数进行比较
print("Logistic Regression 平均交叉验证分数:", cv_scores_log.mean())
print("Decision Tree 平均交叉验证分数:", cv_scores_tree.mean())
print("SVC 平均交叉验证分数:", cv_scores_svc.mean())

通过比较平均分数,您可以选择性能最佳的算法作为基础模型。

模型融合基础:投票法(VotingClassifier/VotingRegressor)

投票法是一种简单有效的模型融合技术,通过组合多个基础模型的预测来提高整体性能。

投票法简介

  • 硬投票:基于多数票决定最终预测,适合分类问题。
  • 软投票:基于模型预测的概率加权平均,适用于支持概率输出的分类器。

使用VotingClassifier和VotingRegressor

Scikit-learn的VotingClassifierVotingRegressor类支持硬投票和软投票。以下示例展示如何创建一个软投票分类器。

示例代码:

from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化基础模型
log_reg = LogisticRegression()
svc = SVC(probability=True)  # 启用概率输出以支持软投票
dec_tree = DecisionTreeClassifier()

# 创建投票分类器
voting_clf = VotingClassifier(
    estimators=[('lr', log_reg), ('svc', svc), ('dt', dec_tree)],
    voting='soft'  # 使用软投票
)

# 训练模型
voting_clf.fit(X_train, y_train)

# 评估性能
accuracy = voting_clf.score(X_test, y_test)
print("投票分类器在测试集上的准确率:", accuracy)

对于回归问题,可以使用VotingRegressor类似操作,只需替换为回归模型如线性回归或随机森林。

堆叠法(StackingClassifier/StackingRegressor:多模型分层融合)

堆叠法是一种更高级的融合技术,通过分层学习来优化预测,第一层模型的输出作为第二层模型的输入。

堆叠法原理

  • 第一层(基础学习器):多个不同的模型生成初步预测。
  • 第二层(元学习器):一个模型(如线性回归或随机森林)学习如何组合基础学习器的预测。

使用StackingClassifier和StackingRegressor

Scikit-learn的StackingClassifierStackingRegressor实现了堆叠法。以下示例展示如何构建一个堆叠分类器。

示例代码:

from sklearn.ensemble import StackingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# 初始化基础学习器列表
base_learners = [
    ('lr', LogisticRegression()),
    ('svc', SVC()),
    ('dt', DecisionTreeClassifier())
]

# 选择元学习器
final_estimator = RandomForestClassifier()

# 创建堆叠分类器
stacking_clf = StackingClassifier(
    estimators=base_learners,
    final_estimator=final_estimator
)

# 训练和评估
stacking_clf.fit(X_train, y_train)
accuracy = stacking_clf.score(X_test, y_test)
print("堆叠分类器在测试集上的准确率:", accuracy)

堆叠法通常能捕捉更复杂的模式,但计算成本更高,适合在模型选择后进一步提升性能。

总结

模型选择与融合是机器学习中提升预测能力的关键环节。通过本章学习,您应该能够:

  • 使用交叉验证公平比较不同算法的性能。
  • 应用投票法简单融合多个模型。
  • 探索堆叠法进行高级分层融合。 Scikit-learn的简洁API使得这些技术易于实现,建议从交叉验证开始,逐步实验融合策略,结合实际数据优化模型。记住,选择合适的方法取决于数据特性和项目需求。
开发工具推荐
Python开发者工具包

包含虚拟环境管理、代码格式化、依赖管理、测试框架等Python开发全流程工具,提高开发效率。特别适合处理复杂数据结构和算法。

获取工具包