9.2 非线性分类算法
Scikit-learn 非线性分类算法教程:SVM, 决策树, kNN, 朴素贝叶斯详解
本教程面向Scikit-learn新手,全面解析非线性分类算法,包括支持向量机(SVM)及其核函数选择、决策树的特征重要性和剪枝策略、k近邻算法的k值和距离度量、朴素贝叶斯变种及文本分类应用。
非线性分类算法教程
介绍
在机器学习中,线性分类器(如逻辑回归)假设数据可以用直线或超平面分离。然而,实际数据往往呈现复杂非线性关系,这时需要非线性分类算法。Scikit-learn 提供了多种高效的非线性分类器,本教程将重点讲解以下四种:支持向量机(SVM)、决策树分类器、k近邻算法和朴素贝叶斯分类器。每种算法都有其独特的优势,适合不同场景。让我们从基本概念入手,逐步深入,并通过代码示例帮助新手快速上手。
什么是非线性分类?
非线性分类指数据无法直接用线性边界划分时,使用曲线或复杂边界进行分类。例如,一个圆形或螺旋形的数据集,需要算法通过映射或递归分割来处理。Scikit-learn 中的非线性分类器通过核技巧、树结构或邻居投票等方法实现这一点。
1. 支持向量机(SVM)
支持向量机(SVM)是一种强大的分类算法,通过核技巧处理非线性数据。在 Scikit-learn 中,SVC 用于分类,SVR 用于回归(本教程聚焦分类)。
核函数选择:从线性到非线性
核函数是 SVM 处理非线性的关键,它通过将数据映射到高维空间来寻找最优分离超平面。Scikit-learn 支持多种核函数:
- 线性核(linear):适用于线性可分数据,计算简单。使用
kernel='linear'。 - 高斯 RBF 核(rbf):最常用的核函数,通过高斯函数映射,能处理复杂非线性模式。使用
kernel='rbf',并可通过gamma参数控制核的宽度。 - 多项式核(poly):基于多项式函数映射,适合某些特定模式。使用
kernel='poly',并指定degree参数控制多项式次数。
选择建议:对于大多数非线性问题,RBF 核是首选;对于简单线性数据,使用线性核;多项式核可尝试于需要显式多项式关系的场景。
代码示例:使用 RBF 核的 SVM
from sklearn.svm import SVC
from sklearn.datasets import make_circles # 生成非线性数据
from sklearn.model_selection import train_test_split
# 生成圆形数据,模拟非线性场景
X, y = make_circles(n_samples=100, noise=0.1, factor=0.5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化 SVM 模型,使用 RBF 核
svm_model = SVC(kernel='rbf', gamma='auto', random_state=42)
svm_model.fit(X_train, y_train)
# 评估模型
accuracy = svm_model.score(X_test, y_test)
print(f"模型准确率: {accuracy:.2f}")
解释:gamma='auto' 自动设置核宽度,通常适用于简单实验。在实际应用中,建议使用 GridSearchCV 调优 gamma 和 C 参数。
2. 决策树分类器(DecisionTreeClassifier)
决策树是一种直观的分类算法,通过递归分割数据构建树结构,天生支持非线性分类,易于解释。
特征重要性
决策树训练后,可以评估每个特征的重要性,这有助于特征选择。在 Scikit-learn 中,通过 feature_importances_ 属性获取。特征重要性表示该特征在分裂节点时的贡献度。
剪枝策略:防止过拟合
决策树容易过拟合(即在训练数据上表现太好,但在测试数据上差),因此需要剪枝策略来控制复杂度。常用参数包括:
max_depth:限制树的最大深度,防止过度生长。min_samples_split:节点分裂所需的最小样本数。min_samples_leaf:叶节点所需的最小样本数。
通过调整这些参数,可以平衡模型复杂度和泛化能力。
代码示例:决策树分类
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
# 加载鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化决策树模型,设置剪枝参数
tree_model = DecisionTreeClassifier(max_depth=3, min_samples_split=5, random_state=42)
tree_model.fit(X_train, y_train)
# 查看特征重要性
print("特征重要性:", tree_model.feature_importances_)
# 评估模型
accuracy = tree_model.score(X_test, y_test)
print(f"模型准确率: {accuracy:.2f}")
提示:可视化决策树(使用 plot_tree)可以帮助理解树的结构和分裂点。
3. k 近邻算法(KNeighborsClassifier)
k近邻(kNN)是一种懒惰学习算法,基于邻居投票进行分类。它不假设数据分布,适合非线性数据,但对噪声敏感。
k 值选择:平衡过拟合和欠拟合
k 值是邻居数量,影响模型性能:
- 小 k 值(如 k=1):模型更复杂,可能过拟合(对噪声敏感)。
- 大 k 值(如 k=10):模型更平滑,可能欠拟合(忽略局部模式)。
选择方法:使用交叉验证(如 GridSearchCV)选择最优 k,避免手动猜测。
距离度量:如何衡量邻居
距离度量定义数据点之间的相似性。Scikit-learn 支持多种度量:
- 欧几里得距离(euclidean):默认值,适合连续特征。
- 曼哈顿距离(manhattan):适合有噪声的数据或高维空间。
通过 metric 参数指定距离度量。
代码示例:kNN 分类
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import make_classification
# 生成模拟非线性数据
X, y = make_classification(n_samples=100, n_features=2, n_informative=2, n_redundant=0, n_clusters_per_class=1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化 kNN 模型,设置 k=5 和使用欧几里得距离
knn_model = KNeighborsClassifier(n_neighbors=5, metric='euclidean')
knn_model.fit(X_train, y_train)
# 评估模型
accuracy = knn_model.score(X_test, y_test)
print(f"模型准确率: {accuracy:.2f}")
实践建议:在大型数据集上,kNN 计算开销大,考虑使用降维或近似算法(如 BallTree)加速。
4. 朴素贝叶斯(NaiveBayes)
朴素贝叶斯基于贝叶斯定理,假设特征条件独立,虽简单但高效,特别适合文本分类等高维数据。
变种及其应用
Scikit-learn 提供多种变种,适应不同数据类型:
- 高斯朴素贝叶斯(GaussianNB):假设特征服从高斯分布,适用于连续数据,如数值特征。
- 伯努利朴素贝叶斯(BernoulliNB):假设特征为二值(0/1),适合布尔型数据,如文本中的词出现与否。
- 多项式朴素贝叶斯(MultinomialNB):假设特征为计数,常用于文本分类(如词频向量)。
文本分类应用
朴素贝叶斯是文本分类的经典算法。结合 TF-IDF 向量化,可以高效处理文本数据。
代码示例:使用多项式朴素贝叶斯进行文本分类
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.datasets import fetch_20newsgroups # 示例文本数据集
# 加载新闻组数据集,选取两个类别
categories = ['alt.atheism', 'soc.religion.christian']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories, shuffle=True, random_state=42)
X_train, y_train = newsgroups_train.data, newsgroups_train.target
# 向量化文本:将文本转换为 TF-IDF 特征
vectorizer = TfidfVectorizer(stop_words='english', max_features=1000) # 限制特征数以加速
X_train_vec = vectorizer.fit_transform(X_train)
# 初始化多项式朴素贝叶斯模型
nb_model = MultinomialNB()
nb_model.fit(X_train_vec, y_train)
# 预测新文本示例
text = ["This is a sample text about religion and belief."]
text_vec = vectorizer.transform(text)
prediction = nb_model.predict(text_vec)
print(f"预测类别: {prediction[0]}") # 0 或 1,对应类别
注意:特征独立性假设在现实中可能不成立,但朴素贝叶斯仍表现良好,尤其在文本分类中。
总结与比较
下表简要比较这些算法,帮助选择合适的非线性分类器:
| 算法 | 优势 | 缺点 | 适用场景 |
|---|---|---|---|
| SVM | 高维数据处理能力强,泛化好 | 计算开销大,参数调优复杂 | 图像分类、文本分类(高维) |
| 决策树 | 易于解释,无需数据预处理 | 容易过拟合,对噪声敏感 | 客户细分、医学诊断 |
| k近邻 | 简单直观,无需训练假设 | 对噪声敏感,计算慢(大数据) | 推荐系统、模式识别 |
| 朴素贝叶斯 | 快速高效,适合高维数据 | 特征独立性假设可能不成立 | 文本分类、垃圾邮件过滤 |
下一步建议
- 实践:在 Scikit-learn 官方数据集(如
load_digits或load_wine)上尝试这些算法。 - 调优:使用
GridSearchCV或RandomizedSearchCV自动搜索最佳参数。 - 评估:学习更多评估指标(如精确率、召回率、F1 分数)以全面评估模型。
通过这些内容,新手可以理解非线性分类的基本概念,并掌握 Scikit-learn 中关键算法的应用。记住,机器学习是实践出真知,多动手编码才能深化理解。