社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

机器学习新手必看:一文搞定SVM算法

景略集智 • 5 年前 • 576 次点击  

机器学习新手必看:一文搞定SVM算法

掌握机器学习算法根本不难。大多数初学者从回归开始学习,虽然学习和使用它很简单,但是这能解决我们的目的吗?当然不能!因为你可以做的不仅仅是回归问题!

我们可以将机器学习算法看作是装满斧头、剑、刀片、弓、匕首等的军械库。你有各种工具,但你应该学会在正确的时间综合使用它们。

作为一个类比,我们将“回归”看作一把能够有效切割数据的剑,但它无法处理高度复杂的数据。相反,支持向量机(SVM)算法就像一把锋利的刀 ——它可以在较小的数据集上工作,而在这些小数据集上面,在搭建模型时,它的性能却强大的多。

本文我们一起学习 SVM 机器学习算法从初级到高级的相关知识。示例+代码,一文搞定支持向量机算法。

目录

  • 什么是 SVM 算法?

  • SVM 算法是如何工作的 ?

  • 如何在 Python 和 R 语言中实现 SVM?

  • 如何调整 SVM 的参数?

  • SVM 的优点和缺点

什么是 SVM 算法?

支持向量机(SVM) 是一个监督学习算法,既可以用于分类问题也可以用于回归问题。但是,SVM算法还是主要用在分类问题中。在 SVM 算法中,我们将数据绘制在 n 维空间中(n 代表数据的特征数),每个特征数的值是特定坐标的值。然后我们通过查找可以将数据分成两类的超平面(请参照下图)。

支持向量指的是观察的样本在 n 维空间中的坐标,SVM 是将样本分成两类的最佳超平面。

SVM 是如何工作的?

上面的简介告诉我们 SVM 是通过超平面将两类样本分开。现在我们的主要问题是:“如何将两类样本分开”。别担心,这个问题并不像你想象的那么难!

首先我们来看:

找出正确的超平面(场景-1):在这里,我们有三个超平面(A,B 和 C)。 现在,找出正确的超平面并用星和圆区分。

对于识别正确的超平面,你需要记住一条经验法则:“选择更好地隔离两个类的超平面”。在我们这里的例子中,超平面“B”非常出色地完成了这项工作。

找出正确的超平面(场景-2):在这里,我们有三个超平面(A,B 和 C),并且所有的类都被进行了很好的隔离。现在,我们如何确定正确的超平面?

在这里,最大化最近数据点(任一类)和超平面之间的距离将有助于我们决定正确的超平面。这个距离被称为边距。请看下图:

在上面,你可以看到,与 A 和 B 相比,超平面 C 的边距很大。因此,我们将这个正确的超平面命名为 C.另一个选择边距更大超平面的原因是鲁棒性更强。如果我们选择一个具有较低边距的超平面,那么很有可能会错误分类。

找出正确的超平面(场景-3):提示:使用上一节讨论的规则来找出正确的超平面。

可能有些朋友会选择超平面 B,因为它比 A 具有更高的边距。但是,这里有一个问题,SVM 算法在选择超平面方面,会优先考虑正确分类而非更大的边距。这里,超平面 B 有分类错误,而 A 已经正确分类。因此,A 才是正确的超平面。

我们能区分两个类别吗?(情景-4):下图中,我无法用一条直线将这两个类别分开,因为其中一个星星位于圈圈的领域之内,也就是离散值。

我之前提过,如果一个星跑到了另一边,那么它就是星类别的离散值。SVM 的一个特征就是会忽略离散值并找到具有最大边距的超平面。因此,我们可以说,SVM 对于离散值具有鲁棒性。

找到超平面以分离不同类(场景-5):在下面的场景中,我们没法让线性超平面区分这两个类,那么 SVM 如何对这两个类进行分类?到现在为止,我们只看到了线性超平面。

SVM 可以很方便的解决这个问题!它可以通过引入附加特征解决。在这里,我们增加一个新的 z=x^2+y^2 特征。现在,我们绘制轴 x 和 z 上的数据点:

在上面的情节中,要考虑的要点是:

  • z 的所有值总是正值,因为它是是 x 和 y 的平方和。

  • 在原始图中,红色点离 x 和 y 轴更近,z 值相对较小,星相对较远会有较大的 z 值。

在 SVM 中,这两个类之间很容易有一个线性超平面。但是,另一个亟待解决的问题是,我们是否需要手动添加此功能才能拥有超平面?不需要,SVM 有一个叫核函数的功能。核函数具有将低维数据转化成高维数据的作用。它将不可分离的问题转换成可分离的问题,这些函数被称为内核。它主要用于非线性分离问题。简而言之,它执行一些非常复杂的数据转换,然后根据您定义的标签或输出找出分离数据的过程。

当我们在原始输入空间中查看超平面时,它看起来像一个圆圈:

现在,我们来看看在数据科学挑战中应用 SVM 算法的方法。

如何在 Python 和 R 语言中实现 SVM?

在 Python 中,scikit-learn 是一个在机器学习算法中被广泛使用的库,SVM 也可以在 scikit-learn 库中使用,并遵循相同的结构(导入库,对象创建,拟合模型和预测)。让我们来看下面的代码:

#Import Library
from sklearn import svm



    
#Assumed you have, X (predictor) and Y (target) for training data set and x_test(predictor) of test_dataset
# Create SVM classification object 
model = svm.svc(kernel='linear', c=1, gamma=1) 
# there is various option associated with it, like changing kernel, gamma and C value. Will discuss more # about it in next section.Train the model using the training sets and check score
model.fit(X, y)
model.score(X, y)
#Predict Output
predicted= model.predict(x_test)

使用 R 语言中的 e1071 软件包可以轻松创建 SVM。它具有辅助函数以及 Naive Bayes 分类器的代码。在 R 和 Python 中创建 SVM 遵循类似的方法,现在让我们来看看下面的代码

#Import Library
require(e1071) #Contains the SVM 
Train <- read.csv(file.choose())
Test <- read.csv(file.choose())
# there are various options associated with SVM training; like changing kernel, gamma and C value.

# create model
model <- svm(Target~Predictor1+Predictor2+Predictor3,data=Train,kernel='linear',gamma=0.2,cost=100)

#Predict Output
preds <- predict(model,Test)
table(preds)

如何调整 SVM 的参数?

为机器学习算法调整参数值有效地提高了模型性能。我们来看看 SVM 可用的参数列表。

sklearn.svm.SVC(C=1.0, kernel='rbf', degree=3, gamma=0.0, coef0=0.0, shrinking=True, probability=False,tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, random_state=None)

我们来讨论其中最重要的三个参数为: “kernel”, “gamma” 和“C”。

Kernel(核函数):我们已经在上面介绍过了。这里,我们有不少适用的选择,如“linear”,“rbf”,“poly”等等(默认值是“rbf”)。这里“rbf”和“poly”对于非线性超平面很有用。我们来看看这个例子,我们已经使用线性核函数对双特征鸢尾花数据集的两个特征进行了分类。

示例:使用线性核函数

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2] # we only take the first two features. We could
 # avoid this ugly slicing by using a two-dim dataset
y = iris.target
# we create an instance of SVM and fit out data. We do not scale our
# data since we want to plot the support vectors
C = 1.0 # SVM regularization parameter
svc = svm.SVC(kernel='linear', C=1,gamma=0).fit(X, y)
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
h = (x_max / x_min)/100
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
 np.arange(y_min, y_max, h))

plt.subplot(1, 1, 1)
Z = svc.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.title('SVC with linear kernel')
plt.show()

示例:使用rbf核函数

将核函数类型更改为 rbf,并查看影响。

svc = svm.SVC(kernel='rbf', C=1,gamma=0).fit(X, y)

如果你有大量特征(> 1000),我建议选择线性核函数,因为它更有可能在高维空间中线性分离。此外,也可以使用 rbf,但不要忘记交叉验证其参数以避免过度拟合。

gamma:rbf 函数、Poly 函数和 S 型函数的系数。gamma 值越大,SVM 就会倾向于越准确的划分每一个训练集里的数据,这会导致泛化误差较大和过拟合问题。

例如:让我们区分一下,如果我们有不同的 gamma 值,如 0,10 或 100:

svc = svm.SVC(kernel='rbf', C=1,gamma=0).fit(X, y)

C: 错误项的惩罚参数 C。它还控制平滑决策边界和正确分类训练点之间的权衡。

我们应该始终考虑交叉验证分数,以有效组合这些参数并避免过度拟合。

在 R 语言中,SVM 的参数可以像在 Python 中一样进行调整。下面提到的是 e1071 软件包中的相应参数:

  • 内核参数可以调整为“线性”,“聚合”,“rbf”等。

  • Gamma 值可以通过设置“Gamma”参数进行调整。

  • Python 中的 C 值由 R 中的“Cost”参数进行调整。

SVM 的优点和缺点

优点:

  • 对于边界清晰的分类问题效果好;

  • 对高维分类问题效果好;

  • 当维度高于样本数的时候,SVM 较为有效;

  • 因为最终只使用训练集中的支持向量,所以节约内存

缺点:

  • 当数据量较大时,训练时间会较长;

  • 当数据集的噪音过多时,表现不好;

  • SVM 不直接提供结果的概率估计,它在计算时直接使用 5 倍交叉验证。

实践练习

找到有效的附加特征,以便得到用于分离类的超平面,如下图所示:

在下面的注释部分回答变量名称。

结尾笔记

在本文中,我们详细介绍了机器学习算法 SVM,包括了它的工作概念,Python 中的实现过程,通过调整参数来优化模型的技巧,优点和缺点,以及最后的练习题。建议你亲自使用 SVM 并通过调整参数来分析该模型的功能。

集智主站出品的 Scikit-learn 教程也曾讲过如何用 SVM 算法解决问题:点击查看


新手福利

假如你是机器学习小白,但又希望能以最高效的方式学习人工智能知识,我们这里正好有个免费学习AI的机会,让你从零到精通变身AI工程师,不了解一下?

机会传送门:戳这里!!

这可能是正点赶上AI这班车的最好机会,不要错过哦。


今天看啥 - 高品质阅读平台
本文地址:http://www.jintiankansha.me/t/BmfdlO7eoc
Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/12828
 
576 次点击