700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 支持向量机SVM Iris数据集 分类预测

支持向量机SVM Iris数据集 分类预测

时间:2020-08-21 11:51:02

相关推荐

支持向量机SVM  Iris数据集  分类预测

目录

支持向量机对iris数据集进行分类预测

1. 基础概念

2. 实验步骤与分析

2.1 数据理解

2.2 数据读入

2.3 训练集和测试集划分

2.4 支持向量机

2.5 预测

2.6 分析

2.7 调整参数,提高预测准确率

3. 总结

支持向量机对iris数据集进行分类预测

1. 基础概念

SVM的主要思想是:建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。SVM也是结构风险最小化方法的近似实现。

2. 实验步骤与分析

2.1 数据理解

由Fisher在1936年整理,包含4个特征(Sepal.Length(花萼长度)、Sepal.Width(花萼宽度)、Petal.Length(花瓣长度)、Petal.Width(花瓣宽度)),特征值都为正浮点数,单位为厘米。目标值为鸢尾花的分类(Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),Iris Virginica(维吉尼亚鸢尾))。

2.2 数据读入

from sklearn import datasetsiris = datasets.load_iris() #加载数据集

2.3 训练集和测试集划分

from sklearn.model_selection import train_test_splitX = iris.data[:,:2]y = iris.target #标签X_train, X_test, y_train, y_test = train_test_split(X,y,test_size = 0.3, random_state = 0)

2.4 支持向量机

lin_svc = svm.SVC(kernel='linear').fit(X_train,y_train) # 核函数kernel为线性核函数rbf_svc = svm.SVC(kernel='rbf').fit(X_train,y_train) # kernel为径向基核函数poly_svc = svm.SVC(kernel='poly',degree=3).fit(X_train,y_train) #kernel为多项式核函数

图1. SVM常用核函数

2.5 预测

lin_svc_pre = lin_svc.predict(X_test) # linear核函数的svm进行预测rbf_svc_pre = rbf_svc.predict(X_test) # 径向基rbf核函数 预测poly_svc_pre = poly_svc.predict(X_test) #多项式核函数预测

2.6 分析

# score函数,根据给定数据与标签返回正确率的均值acc_lin_svc = lin_svc.score(X_test,y_test) acc_rbf_svc = rbf_svc.score(X_test,y_test)acc_poly_svc = poly_svc.score(X_test,y_test)

图2.花萼的长宽和目标的关系图

图2. 花萼的长宽和目标,在 #sepal 和 target 的关系 代码块中画出。

图3.准确率acc_lin/rbf/poly_svc和三种不同kernel的预测值

图3. 准确率acc_lin_svc = 0.8 #表示以kernel为线性函数时,准确率为0.8

acc_lin_predicted: #表示预测值

其余类似.

图4.三种核函数对应的SVM分类结果图

图4,将分类结果可视化,不过这是超平面分类后投影的结果。分类边界参考图2.目标关系图,便可以了解分类的情况。可以发现在训练数据集X = iris.data[:,:2]时,使用svm的不同核函数预测准确性都不高。

2.7 调整参数,提高预测准确率

调整训练和测试数据集的大小,只需要在训练集和测试集那里进行调整使用径向基函数rbf作为核函数来进行对比,svm.SVC(C=1.0, kernel = 'rbf', gamma = 'auto') #output#X = iris.data[:,:4]#SVM模型准确率: 0.9777777777777777#X = iris.data[:,:3]#SVM模型准确率: 0.9555555555555556#X = iris.data[:,:2]#SVM模型准确率: 0.8

很明显,数据集的增大提供了更精准的预测。

3. 总结

这里列出的都是简单库函数,也就是调包就可以完成的,但是真正要理解还是得从数学原理推导,详见《统计学习方法》。有时间我也会把推导好的编辑好发出来,现在是初学入门阶段,自己写一些实验有一些感性的认识,不过也不能缺少理性的认识,就是从数学原理上的推导。

完整代码:

from sklearn import datasetsfrom sklearn import svmfrom sklearn.model_selection import train_test_splitimport numpy as np#一、iris数据获取iris = datasets.load_iris() #加载数据集print(type(iris),dir(iris)) #打印数据集属性y = iris.target#可视化#sepal 和 target 的关系X = iris.data[:,:2] #取所有行的前两位元素plt.scatter(X[y==0,0],X[y==0,1],color = 'r',marker='o') #X中y==0的,索引为0的元素(共50个) ,即第一个元素为横坐标,索引为1的为plt.scatter(X[y==1,0],X[y==1,1],color = 'b',marker='*')plt.scatter(X[y==2,0],X[y==2,1],color = 'g',marker='+')plt.title('the relationship between sepal and target classes')plt.xlabel('sepal length')plt.ylabel('sepal width')plt.show()#二、数据预处理#pass#三、模型的训练X = iris.data[:,:2]y = iris.target #标签X_train, X_test, y_train, y_test = train_test_split(X,y,test_size = 0.3, random_state = 0)lin_svc = svm.SVC(decision_function_shape='ovo', kernel = 'rbf', gamma = 'auto') #ovo,one versus one 一对一的分类器,这时对K个类别需要构建K * (K - 1) / 2个分类器lin_svc.fit(X_train,y_train)#ovr:one versus rest,一对其他,这时对K个类别只需要构建K个分类器。rbf_svc = svm.SVC(decision_function_shape='ovo',kernel='linear')rbf_svc.fit(X_train, y_train)poly_svc = svm.SVC(decision_function_shape = 'ovo', kernel = 'poly',degree = 3)poly_svc.fit(X_train, y_train)#四、模型的评估#分类结果可视化#可视化分类结果# the step of the gridh = .02 # to create the grid , so that we can plot the images on itx_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, h),np.arange(y_min, y_max, h))# the title of the graphtitles = ['LinearSVC (linear kernel)','SVC with RBF kernel','SVC with polynomial (degree 3) kernel']for i, clf in enumerate((lin_svc, rbf_svc, poly_svc)):# to plot the edge of different classes# to create a 2*2 grid , and set the i image as current imageplt.subplot(2, 2, i + 1) # set the margin between different sub-plotplt.subplots_adjust(wspace=0.4, hspace=0.4)# SVM input :xx and yy output: an arrayZ = clf.predict(np.c_[xx.ravel(), yy.ravel()]) # to plot the resultZ = Z.reshape(xx.shape) #(220, 280)plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8) plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)plt.xlabel('Sepal length')plt.ylabel('Sepal width')plt.xlim(xx.min(), xx.max())plt.ylim(yy.min(), yy.max())plt.xticks(())plt.yticks(())plt.title(titles[i])plt.show()acc_lin_svc = lin_svc.score(X_test,y_test) #根据给定数据与标签返回正确率的均值acc_rbf_svc = rbf_svc.score(X_test,y_test)acc_poly_svc = poly_svc.score(X_test,y_test)lin_svc_pre = lin_svc.predict(X_test)rbf_svc_pre = rbf_svc.predict(X_test)poly_svc_pre = poly_svc.predict(X_test)print('acc_lin_svc: ',acc_lin_svc)print('acc_lin_predicted: ',lin_svc_pre)print('acc_rbf_svc: ',acc_rbf_svc)print('acc_rbf_predicted: ',rbf_svc_pre)print('acc_poly_svc: ',acc_poly_svc)print('acc_poly_predicted: ',poly_svc_pre)#五、模型的优化#pass#六、模型持久化#pass

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。