计算机系统应用教程网站

网站首页 > 技术文章 正文

浅谈支持向量机(3)

btikc 2024-11-19 02:02:57 技术文章 1 ℃ 0 评论

以下代码对比分析四种支持向量机分类预测的效果。使用到的支持向量机有:线性支持向量机、非线性支持向量机(线性核、高斯核、多项式核)

from sklearn.datasets import load_wine #读取sklearn库中的红酒数据集
from sklearn.svm import SVC,LinearSVC  #导入支持向量机中要用的库
import matplotlib.pyplot as plt        #导入绘图库
import numpy as np                     #导入计算库

#第一步:准备数据
data=load_wine()     #读入红酒数据集的数据
X=data.data[:,:2]    #取出数据的前两列特征
y=data.target        #取出数据的标签值     
x0,x1=X[:,0],X[:,1]  #把前两列特征分别赋值给x0和x1
x_min,x_max=x0.min()-1,x0.max()+1  #第一列特征的最小值减1和最大值加1
y_min,y_max=x1.min()-1,x1.max()+1  #第二列特征的最小值减1和最大值加1
xx,yy=np.meshgrid(np.arange(x_min,x_max,0.05),np.arange(y_min,y_max,0.05))
#构建两个序列,分别以x_min和y_min为起始值,差值为0.05,终值为x_max和y_max(最大只能取到max-0.05)

第二步:构建分类模型
models=(LinearSVC(),SVC(kernel="linear"), SVC(kernel='rbf',gamma=0.7),SVC(kernel='poly',degree=4))
#定义使用的线性支持向量机、核为线性的支持向量机、核为高斯的支持向量机,核为多项式的支持向量机
models=(svm.fit(X,y) for svm in models)  #使用红酒数据集来训练四个支持向量机

第三步:绘图准备
fig,ax=plt.subplots(2,2) #准备两行两列的画布,为绘制4个子图做准备
plt.subplots_adjust(wspace=0.4,hspace=0.4)#调整子图布局,wspace表示子图之间保留的宽度,hspace表示子图之间保留的高度
titles=('LinearSVC','SVC with linear kernel','SVC with RBF kernel','SVC with polynomial')
#定义四个支持向量机的名称,也就是绘制的4个子图名称
def plot_contours(ax,svm): #定义函数:绘制分类的边界
    z=svm.predict(np.c_[xx.ravel(),yy.ravel()]).reshape(xx.shape)
    #计算四个支持向量机预测的标签值,并把结果按照xx的形状重构
    return ax.contourf(xx,yy,z,cmap=plt.cm.viridis)
    #返回绘制好的分类边界,contourf函数是绘制带填充的等高线
    
第四步:开始绘图
for model,title,ax in zip(models,titles,ax.flatten()):#循环读取四个模型、标题和子图
    plot_contours(ax,model) #调用绘制分类边界函数
    ax.scatter(x0,x1,c=y,cmap=plt.cm.magma,s=20,edgecolor='k')
    #以x0和x1确定坐标点,绘制散点图
    ax.set_xlim(xx.min(),xx.max())
    ax.set_ylim(yy.min(),yy.max())
    ax.set_xlabel('Feature 0')
    ax.set_ylabel('Feature 1')
    ax.set_title(title)  #设置每张子图的标题
plt.show()

例子使用的是sklearn自带的红酒数据集,共有178条数据,每条数据13个分类。取数据集的前两列,绘制散点图和分类边界。四个支持向量机都使用了默认的正则化系数,C=1,参数kernel指定的是非线性支持向量机使用的核,高斯核中的参数gamma表示核宽度,也就是径向作用范围。增大gamma值效果如下,过于拟合当前数据,模型复杂度高,泛化能力差。

多项式支持向量机中参数degree表示多项式的最高次幂。zip函数将对象中对应的元素打包成一个个元组,然后返回这些元组组成的列表。这里就是把模型、标题、子图序号分别打包成三个元组,model循环读取模型列表,title循环读取标题列表,ax.flatten()是把ax一维化,然后逐一取出每一个要绘制的子图。

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表