13.2 回归实战:波士顿房价预测
Scikit-learn回归实战:波士顿房价预测完整教程
本教程详细介绍使用Scikit-learn进行波士顿房价预测回归项目的全过程,涵盖数据探索、特征分析、缺失值和异常值处理、标准化、多种回归算法训练、模型性能评估与超参数优化,适合机器学习新手入门。
推荐工具
回归实战:波士顿房价预测
引言
欢迎来到Scikit-learn高级教程章节!在本章中,我们将以波士顿房价预测为例,系统学习回归任务的核心步骤。通过这个实战项目,您将掌握从数据探索到模型调优的全过程,并理解回归问题的关键注意事项。波士顿房价数据集是机器学习领域的经典案例,适合初学者练习。
1. 项目需求与数据探索(特征相关性分析)
项目需求
- 目标:使用波士顿房价数据集,基于房屋特征(如房间数、犯罪率等)预测房价(中位数)。
- 数据集:Scikit-learn内置的波士顿房价数据集,包含506个样本和13个特征。
- 任务类型:监督学习回归任务,输出是连续值(房价)。
数据探索
首先,加载数据并检查其结构:
from sklearn.datasets import load_boston
import pandas as pd
import numpy as np
# 加载波士顿房价数据集
boston = load_boston()
X = pd.DataFrame(boston.data, columns=boston.feature_names)
y = pd.Series(boston.target, name='MEDV') # MEDV 是中位数房价
# 查看数据集基本信息
print("数据集形状:", X.shape, y.shape)
print("特征名称:", boston.feature_names)
print("目标变量 (MEDV):", y.describe())
特征相关性分析
相关性分析帮助识别与房价相关的关键特征:
import matplotlib.pyplot as plt
import seaborn as sns
# 计算相关系数矩阵
corr_matrix = X.corrwith(y)
print("特征与房价的相关系数:")
print(corr_matrix.sort_values(ascending=False))
# 可视化相关性热图
plt.figure(figsize=(10, 8))
sns.heatmap(X.corr(), annot=True, cmap='coolwarm', fmt='.2f')
plt.title("特征间相关性热图")
plt.show()
关键点:高相关性特征(如RM房间数)通常是重要预测因子,而低相关性或负相关性特征可能需要进一步处理。
2. 数据预处理(缺失值 / 异常值处理、标准化)
缺失值处理
波士顿房价数据集通常没有缺失值,但需检查并处理:
# 检查缺失值
print("缺失值数量:", X.isnull().sum().sum()) # 应为0,因为是内置数据集
# 如果有缺失值,可使用填充或删除,例如:
# X.fillna(X.mean(), inplace=True) # 用均值填充
异常值处理
使用箱线图识别并处理异常值:
# 识别异常值(以RM特征为例)
plt.figure(figsize=(6, 4))
sns.boxplot(x=X['RM'])
plt.title("RM特征箱线图")
plt.show()
# 处理异常值(例如,使用IQR方法)
Q1 = X['RM'].quantile(0.25)
Q3 = X['RM'].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
X_clean = X[(X['RM'] >= lower_bound) & (X['RM'] <= upper_bound)]
print("处理异常值后样本数:", X_clean.shape[0])
标准化
标准化(Standardization)使特征均值为0、方差为1,提高模型性能:
from sklearn.preprocessing import StandardScaler
# 标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_clean) # X_clean是处理后的特征
print("标准化后的特征示例:", X_scaled[:5])
3. 回归算法训练(线性回归 / 随机森林 / XGBoost)
划分训练集和测试集
from sklearn.model_selection import train_test_split
# 划分数据集,80%训练,20%测试
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
print("训练集大小:", X_train.shape, "测试集大小:", X_test.shape)
线性回归
线性回归是基础回归方法:
from sklearn.linear_model import LinearRegression
# 训练线性回归模型
lr_model = LinearRegression()
lr_model.fit(X_train, y_train)
print("线性回归系数:", lr_model.coef_)
随机森林回归
随机森林适用于非线性关系:
from sklearn.ensemble import RandomForestRegressor
# 训练随机森林回归模型
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)
XGBoost回归
XGBoost是强大的梯度提升算法:
import xgboost as xgb
# 训练XGBoost回归模型
xgb_model = xgb.XGBRegressor(n_estimators=100, learning_rate=0.1, random_state=42)
xgb_model.fit(X_train, y_train)
4. 模型评估(RMSE/R²)与超参数调优
模型评估
使用RMSE(均方根误差)和R²(决定系数)评估性能:
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
# 评估函数
def evaluate_model(model, X_test, y_test):
y_pred = model.predict(X_test)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
r2 = r2_score(y_test, y_pred)
return rmse, r2
# 评估三个模型
print("线性回归 - RMSE: {:.2f}, R²: {:.2f}".format(*evaluate_model(lr_model, X_test, y_test)))
print("随机森林 - RMSE: {:.2f}, R²: {:.2f}".format(*evaluate_model(rf_model, X_test, y_test)))
print("XGBoost - RMSE: {:.2f}, R²: {:.2f}".format(*evaluate_model(xgb_model, X_test, y_test)))
超参数调优
使用网格搜索优化随机森林参数:
from sklearn.model_selection import GridSearchCV
# 定义参数网格
param_grid = {
'n_estimators': [50, 100, 150],
'max_depth': [None, 10, 20],
'min_samples_split': [2, 5, 10]
}
# 网格搜索
grid_search = GridSearchCV(RandomForestRegressor(random_state=42), param_grid, cv=5, scoring='r2')
grid_search.fit(X_train, y_train)
print("最佳参数:", grid_search.best_params_)
print("最佳R²分数:", grid_search.best_score_)
5. 项目总结:回归任务的核心注意事项
- 数据质量是关键:确保数据清洁,处理缺失值和异常值,避免噪声影响模型。
- 特征工程很重要:相关性分析帮助选择特征,标准化提高算法收敛性。
- 算法选择需明智:线性回归适合线性关系,随机森林和XGBoost处理复杂模式,但可能过拟合。
- 评估指标要合理:RMSE衡量误差大小,R²表示模型解释方差的比例,结合使用以获得全面评估。
- 超参数调优优化性能:使用交叉验证(如GridSearchCV)避免过拟合,提升模型泛化能力。
- 项目可扩展性:在实战中,考虑数据更新、模型部署和监控,持续改进模型。
通过本章教程,您已掌握Scikit-learn回归项目的全流程。动手实践这些步骤,并尝试应用到其他数据集,以深化理解。祝您学习愉快!
开发工具推荐