4.3 数据划分与重采样
Scikit-learn数据划分与重采样教程:从基础到高级
本教程详细讲解Scikit-learn中的数据划分和重采样技术,覆盖训练集测试集划分、分层抽样、交叉验证方法以及处理类别不平衡的重采样策略,帮助初学者和工程师提升模型性能。
Scikit-learn数据划分与重采样指南
在机器学习项目中,数据划分和重采样是确保模型泛化能力和处理数据不平衡的关键步骤。本教程将带你了解Scikit-learn中常用的方法,从简单的训练集/测试集划分到高级的重采样技术。
为什么需要数据划分和重采样?
数据划分将数据集分成训练集和测试集,以评估模型在未见数据上的表现,防止过拟合。重采样则用于调整数据集中的类别分布,解决类别不平衡问题,从而提高少数类的预测准确性。
训练集/测试集划分:train_test_split()
Scikit-learn提供train_test_split()函数,用于快速划分数据。常用参数包括:
test_size:指定测试集的比例,如0.2表示20%的数据作为测试集。random_state:设置随机种子,确保划分结果可重复。stratify:进行分层抽样,保持类别分布一致。
示例代码:
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target
# 划分数据集,使用20%作为测试集,设置随机种子,并分层抽样
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
print(f"训练集大小: {len(X_train)}, 测试集大小: {len(X_test)}")
解释: 这里,stratify=y确保训练集和测试集中每个类别的比例与原始数据集相同,适用于不平衡数据。
分层抽样:解决类别不平衡问题
分层抽样通过stratify参数实现,在划分数据时保持类别分布。这对于分类问题尤其重要,避免某些类别在测试集中缺失或过少。
应用场景: 当数据集中某些类别样本较少时,使用分层抽样可以确保模型训练和评估时所有类别都被充分代表。
交叉验证数据划分
交叉验证是一种更稳健的评估方法,将数据多次划分以平均模型性能。Scikit-learn提供多种交叉验证划分器:
KFold
- 将数据分成k个折叠,每次使用k-1个折叠训练,1个折叠测试,重复k次。
- 适合数据分布均匀的情况。
StratifiedKFold
- 类似KFold,但在划分时进行分层,保持类别分布。
- 推荐用于分类问题,特别是类别不平衡时。
ShuffleSplit
- 随机划分数据多次,每次划分时打乱数据,适合大型数据集或需要多次随机采样的情况。
示例代码:
from sklearn.model_selection import KFold, StratifiedKFold, ShuffleSplit
# KFold示例
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X, y):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
# 训练和评估模型
# StratifiedKFold示例
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in skf.split(X, y):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
# ShuffleSplit示例
ss = ShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
for train_index, test_index in ss.split(X):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
提示: 使用cross_val_score()函数可以简化交叉验证过程。
重采样方法:处理极端类别不平衡
当数据集中类别极度不平衡时(如欺诈检测),重采样技术可以调整数据集。Scikit-learn本身不直接提供重采样函数,但可以与imbalanced-learn库结合使用。
上采样(Oversampling)
- 增加少数类样本的数量,例如通过复制或生成新样本(如SMOTE算法)。
- 优点:保留所有信息;缺点:可能导致过拟合。
下采样(Undersampling)
- 减少多数类样本的数量,以平衡类别分布。
- 优点:计算效率高;缺点:可能丢失重要信息。
示例使用imbalanced-learn:
# 安装库: pip install imbalanced-learn
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
# 上采样示例
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
# 下采样示例
rus = RandomUnderSampler(random_state=42)
X_resampled, y_resampled = rus.fit_resample(X_train, y_train)
最佳实践: 在交叉验证循环内进行重采样,避免数据泄露。
总结
- 使用
train_test_split()进行基础划分,结合stratify处理不平衡数据。 - 交叉验证(如KFold、StratifiedKFold)提供更可靠的模型评估。
- 对于极端不平衡,采用重采样方法(上采样或下采样)优化数据集。
- 总是设置
random_state以确保结果可重复。
通过本教程,你应该能熟练应用Scikit-learn中的数据划分和重采样技术,提升机器学习项目的效果。实践中,根据具体问题选择合适的组合方法。