Scikit-learn 中文教程

第二部分:Scikit-learn 核心基础
第 3 章 Scikit-learn 核心设计与 API 体系
第 4 章 数据集模块与数据划分
第三部分:数据预处理与特征工程
第 5 章 数据预处理核心模块(sklearn.preprocessing)
第 6 章 特征工程:提取、选择与构建
第四部分:模型评估与验证
第 7 章 模型评估指标(按任务类型划分)
第 8 章 模型验证与超参数调优
第五部分:Scikit-learn 核心算法模块
第 9 章 有监督学习:分类算法
第 10 章 有监督学习:回归算法
第 11 章 无监督学习:聚类与密度算法
第 12 章 半监督学习与其他常用算法
第八部分:性能优化与问题解决
第 18 章 Scikit-learn 性能优化
第 19 章 Scikit-learn 常见问题与解决方案

10.2 非线性回归算法

Scikit-learn非线性回归算法详解:决策树回归、k近邻回归与支持向量回归

Scikit-learn 中文教程

本教程全面介绍Scikit-learn中非线性回归算法,包括决策树回归器的剪枝技巧、k近邻回归器的k值与距离度量优化以及支持向量回归的核函数应用,适合机器学习初学者入门。

推荐工具
PyCharm专业版开发必备

功能强大的Python IDE,提供智能代码补全、代码分析、调试和测试工具,提高Python开发效率。特别适合处理列表等数据结构的开发工作。

了解更多

非线性回归算法在Scikit-learn中的实现

介绍

非线性回归是机器学习中的重要概念,用于处理特征与目标变量之间非线性关系。与线性回归不同,非线性回归允许模型更灵活地拟合复杂数据模式。Scikit-learn提供了多种算法来实现非线性回归,本教程将重点介绍三种常用方法:决策树回归器(DecisionTreeRegressor)、k近邻回归器(KNeighborsRegressor)和支持向量回归(SVR)。这些算法简单易用,适合新人学习,并能有效处理各种非线性回归任务。

决策树回归器(DecisionTreeRegressor)

决策树回归器通过递归分割数据来建模,自动捕获非线性关系和特征交互。它的优势在于直观易懂,无需对数据做线性假设。

  • 处理非线性关系:决策树将数据划分为多个区域,每个区域拟合一个简单模型(如常数),从而组合成复杂非线性函数。这使得它能适应各种数据分布,例如曲线或不规则模式。
  • 剪枝:决策树容易过拟合(即对训练数据过于复杂,泛化能力差)。剪枝技术通过限制树的深度或节点分割条件来防止过拟合。在Scikit-learn的DecisionTreeRegressor中,常用参数包括:
    • max_depth:树的最大深度,限制分割次数。
    • min_samples_split:节点分裂所需的最小样本数。
    • min_samples_leaf:叶节点所需的最小样本数。 调整这些参数可以帮助平衡模型复杂度和性能。

代码示例

from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np

# 假设已有特征X和目标y数据
# 生成示例数据(用于演示)
np.random.seed(42)
X = np.random.rand(100, 1) * 10  # 100个样本,1个特征
y = np.sin(X).ravel() + np.random.normal(0, 0.1, 100)  # 非线性关系加噪声

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化决策树回归器,设置最大深度进行剪枝
model = DecisionTreeRegressor(max_depth=3, random_state=42)
model.fit(X_train, y_train)

# 预测和评估
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"决策树回归的均方误差: {mse:.4f}")

注意事项:决策树回归器可能对数据噪声敏感,建议通过交叉验证调优参数并可视化树结构以检查过拟合。

k近邻回归器(KNeighborsRegressor)

k近邻回归是一种基于实例的学习方法,通过查找k个最近邻点的平均值进行预测。它简单且适用于局部数据模式。

  • k值优化:k值(邻居数)是核心参数。k过小可能导致过拟合(受噪声影响),k过大则可能忽略局部结构,导致欠拟合。通常使用交叉验证选择最佳k值,例如在1到20范围内测试。
  • 距离度量优化:距离度量决定如何计算邻居的远近。Scikit-learn支持多种度量,如:
    • euclidean(欧几里得距离,默认):常用标准距离。
    • manhattan(曼哈顿距离):适用于高维或稀疏数据。
    • minkowski(闵可夫斯基距离):可调整参数p以切换距离类型。 根据数据特性选择度量可以提升模型精度。

代码示例

from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import GridSearchCV

# 假设使用上面示例的X_train, y_train等数据
# 初始化k近邻回归器
model = KNeighborsRegressor()

# 定义参数网格搜索
param_grid = {
    'n_neighbors': [3, 5, 10],  # 测试不同k值
    'metric': ['euclidean', 'manhattan']  # 测试距离度量
}
grid_search = GridSearchCV(model, param_grid, cv=5, scoring='neg_mean_squared_error')
grid_search.fit(X_train, y_train)

# 输出最佳参数和评估
print(f"最佳参数: {grid_search.best_params_}")
best_model = grid_search.best_estimator_
y_pred = best_model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f"k近邻回归的均方误差: {mse:.4f}")

注意事项:k近邻回归对数据尺度敏感,建议先进行特征标准化(如使用StandardScaler),并在小数据集上效率较高。

支持向量回归(SVR)

支持向量回归基于支持向量机框架,通过核函数将数据映射到高维空间处理非线性关系。它泛化能力强,适合复杂数据。

  • 核函数适配:核函数是SVR处理非线性的关键。常见核函数包括:
    • linear:线性核,适用于线性关系。
    • poly(多项式核):通过多项式阶数控制非线性度。
    • rbf(径向基函数核,默认):最常用非线性核,通过参数gamma调节核宽度。
    • sigmoid:S形核,适用于某些特定分布。 选择核函数和调优参数(如C正则化参数和gamma)对性能至关重要。

代码示例

from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler

# 标准化特征,SVR对数据尺度敏感
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# 初始化SVR,使用RBF核函数
model = SVR(kernel='rbf', C=1.0, gamma='scale')  # gamma='scale'基于数据自动调整
model.fit(X_train_scaled, y_train)

# 预测和评估
y_pred = model.predict(X_test_scaled)
mse = mean_squared_error(y_test, y_pred)
print(f"SVR的均方误差: {mse:.4f}")

注意事项:SVR可能计算开销较大,尤其在大数据集上;通过缩放特征和调优参数可以提高效率。

总结

  • 决策树回归器:易于解释和可视化,自动处理非线性,但可能过拟合;使用剪枝参数优化。
  • k近邻回归器:简单直观,适用于局部模式;优化k值和距离度量以提升准确性。
  • 支持向量回归:泛化性能好,适合复杂非线性关系;核函数选择和参数调整是关键。

实践建议:对于新项目,建议从决策树开始,因为它容易调优和解释。如果数据有强局部特性,尝试k近邻回归;对于复杂非线性问题,SVR可能更有效。始终使用交叉验证和评估指标(如均方误差)比较不同算法,并根据数据规模和应用场景做选择。

延伸学习

  • 探索Scikit-learn文档了解更多参数和高级用法。
  • 尝试组合这些算法(如集成方法)以进一步提升性能。
  • 实践中,数据预处理(如缺失值处理)对非线性回归同样重要。

本教程帮助您入门Scikit-learn非线性回归,祝您学习顺利!

开发工具推荐
Python开发者工具包

包含虚拟环境管理、代码格式化、依赖管理、测试框架等Python开发全流程工具,提高开发效率。特别适合处理复杂数据结构和算法。

获取工具包