700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 支持向量机(SVM)实现MNIST手写体数字识别

支持向量机(SVM)实现MNIST手写体数字识别

时间:2024-03-23 06:49:05

相关推荐

支持向量机(SVM)实现MNIST手写体数字识别

一、SVM算法简述

支持向量机即Support Vector Machine,简称SVM。一听这个名字,就有眩晕的感觉。支持(Support)、向量(Vector)、机器(Machine),这三个毫无关联的词,硬生生地凑在了一起。从修辞的角度,这个合成词最终落脚到”Machine”上,还以为是一种牛X的机器呢?实际上,它是一种算法,是效果最好的分类算法之一。

SVM是最大间隔分类器,它能很好地处理线性可分的问题,并可推广到非线性问题。实际使用的时候,还需要考虑噪音的问题。

本文只是一篇学习笔记,主要参考了July、pluskid等人相关文章。将要点记录下来,促进自己的进步。

SVM是最大间隔分类器

既然SVM是用来分类的,咱就举个简单的例子,看看这个SVM有啥特点。如下图所示,有一个二维平面,平面上有两种不同的数据,分别用圈和叉表示。由于这些数据是线性可分的,可以用一条直线将这两个数据分开,这样的直线可以有无数条。

绿线、粉红线、黑线都能将两类区分开。但是那种更好呢?感觉上黑线似乎更好些。粉红线和绿线都离样本太近。要是样本或分界线稍稍有些扰动,分类就可能出错。黑线好就好在离两类都有一个安全间隔(蓝线与黑线间的间隔),即使有些扰动,分类还是准确的。这个安全间隔,也就是“Margin”,当然我们觉得间隔越大分类越准确。

这种分类思想该作何理解呢,他和逻辑回归的分类有何区别呢?

当用逻辑回归的思想来处理分类问题时(将数据分成正负两类:正类y=1,负类y=0)。逻辑回归函数反映的是数据是正类的概率,当这个概率大于0.5时,预测这个数据是正类,反之,小于0.5时,预测这个数据是负类。它优化的目标是预测出错的概率越小越好。可以参看这里

SVM则不同,它要找出一条离两类都有一定安全间隔的分界线(专业点叫超平面)。优化的目标就是安全间隔越大越好。

因此,SVM也被叫做最大间隔分类器。

线性可分的情况

SVM是通过间隔来分类。我们怎么来定量地表达呢?先来看看线性可分的情况,分类函数

f(x)=wTx+bf(x)=wTx+b

xx是特征向量,ww是与特征向量维数相同的向量,也叫权重向量,bb是一个实数,也叫偏置。当f(x)=0f(x)=0时,表达的就是SVM的分类边界,也就是超平面。SVM分成的类y可以为1或-1(注意,与逻辑回归不同,不是1和0)。f(x)f(x)大于0的点对应y=1的数据,f(x)f(x)小于0的点对应y=-1的数据。那我们关心的间隔怎么表达?

先来看看函数间隔,用γ^γ^表示:γ^=y(wTx+b)=yf(x)γ^=y(wTx+b)=yf(x)。|f(x)||f(x)|值越大,也就是yf(x)yf(x)越大,数据点离超平面越远,我们越能确信这个数据属于哪一类别,这是最直观的认识。

那这个是不是就完美表达了我们想要的间隔呢?看看这种情况,固定超平面,当ww,bb同时乘以2,这个间隔就扩大了两倍。那怎么表达不受参数缩放的变化影响的间距呢?老老实实来画个图看看咯。

xx是超平面外的一点,它离超平面的距离是γγ,显然ww是超平面的法向量,x0x0是xx在超平面的投影。则x=x0+γw||w||x=x0+γw||w||,其中||w||||w||是范数,用初等数学来理解就是向量的长度,也叫向量的模。因为在超平面上,f(x0)=0f(x0)=0,等式两边乘以wTwT,再加上一个bb,化简可得γ=wTx+b||w||=f(x)||w||γ=wTx+b||w||=f(x)||w||。注意这个γγ是可正可负的,为了得到绝对值,乘以一个对应的类别y,即可得出几何间隔(用γ~γ~表示)的定义:

γ~=yf(x)∥w∥=γ^∥w∥γ~=yf(x)‖w‖=γ^‖w‖

这个γ~γ~是不受参数缩放影响的。于是,我们的SVM的目标函数就是

maxγ~maxγ~

,当然它得满足一些条件,根据margin的含义

yi(wTxi+b)=γ^i≥γ^,i=1,…,nyi(wTxi+b)=γ^i≥γ^,i=1,…,n

其中γ^=γ~∥w∥γ^=γ~‖w‖.之前说过,即使超平面固定,γ~γ~的值也会随着||w||||w||的变化而变化。由于我们的目标就是要确定超平面,因此可以将无关的变量固定下来,固定的方式有两种:一是固定||w||||w||,当我们找到最优的γ~γ~时γ^γ^也就随之而固定;二是反过来固定γ^γ^,此时||w||||w||也可以根据最优的γ~γ~得到。出于方便推导和优化的目的,我们选第二种,令γ^=1γ^=1,则我们的目标函数化为:

max1∥w∥,s.t.,yi(wTxi+b)≥1,i=1,…,nmax1‖w‖,s.t.,yi(wTxi+b)≥1,i=1,…,n

支持向量作何理解

说了这么多,也没有说到Support vector(支持向量),仔细观看下图:

有两个支撑着中间的分界超平面的超平面,称为gap。它们到分界超平面的距离相等。这两个gap上必定会有一些数据点。如果没有,我们就可以进一步扩大margin了,那就不是最大的margin了。这些经过gap的数据点,就是支持向量(Support Vector)(它们支持了中间的超平面)。很显然,只有支持向量才决定超平面,其他的数据点不影响超平面的确定。

这是一个十分优良的特性。假设有100万个数据点,支持向量100个,我们实际上只需要用这100个支持向量进行计算!!!这将大大提高存储和计算的性能。

线性SVM的求解

考虑目标函数:max1∥w∥,s.t.,yi(wTxi+b)≥1,i=1,…,nmax1‖w‖,s.t.,yi(wTxi+b)≥1,i=1,…,n

由于求的1||w||1||w||最大值相当于求12∥w∥212‖w‖2的最小值,所以上述目标函数等价于:

min12∥w∥2s.t.,yi(wTxi+b)≥1,i=1,…,nmin12‖w‖2s.t.,yi(wTxi+b)≥1,i=1,…,n

1/2是方便求导时约去。这时目标函数是二次的,约束条件是线性的,所以它是一个凸二次规划问题。这个问题可以用现成的QP(Quadratic Programming)的优化包进行求解。但是这个问题还有些特殊的结构,可以通过Lagrange Duality变换到对偶变量的优化问题。通常求解对偶变量优化问题的方法比QP优化包高效得多,而且推导过程中,可以很方便地引出核函数。

简单地说,通过给每个约束条件加上一个拉格朗日乘子,我们可以将它们融和到目标函数里去,拉格朗日函数如下:

L(w,b,α)=12∥w∥2−∑i=1nαi(yi(wTxi+b)−1)L(w,b,α)=12‖w‖2−∑i=1nαi(yi(wTxi+b)−1)

这里还需要说明一点。当xixi不是支持向量时,αi=0αi=0;当xixi是支持向量时,yi(wTxi+b)−1=0yi(wTxi+b)−1=0。这个其实很好理解,因为超平面由支持向量决定,非支持向量不会影响到参数w。

这里省略掉推导的过程,这个函数经过变换,并且满足KKT条件。会得出如下结论:

w=∑i=1nαiyixiw=∑i=1nαiyixi

∑i=1nαiy=0∑i=1nαiy=0

求解的问题可以变换为

maxα∑i=1nαi−12∑i,j=1nαiαjyiyjxTixjs.t.αi≥0,i=1,…,nmaxα∑i=1nαi−12∑i,j=1nαiαjyiyjxiTxjs.t.αi≥0,i=1,…,n

上式可以通过SMO算法求出拉格朗日乘子αα,进而求出ww,通过

b=−maxyi=−1wTxi+minyj=1wTxj2b=−maxyi=−1wTxi+minyj=1wTxj2

,求出b

处理非线性问题

通过上面的讨论,我们表达了SVM的目标函数,并给出了求解的方法。于是SVM就讲完了,可以休息了?细心的读者一定发现,上面是在线性可分的前提下展开讨论的。线性不可分的时候怎么办?

那可不可以将非线性问题转换成线性问题呢?先来看个例子。

二维平面上,这是一个典型的线性不可分的问题。但我们增加一些特征,将数据点映射到高维空间,他就变成了线性可分的点集了。如下图:

事实上,将任何线性不可分的点集映射到高维空间(甚至可以到无穷维空间),总能变成线性可分的情况。只不过维数越高,计算量越大。维数大到无穷的时候,就是一场灾难了。

现在我们还是从数学上梳理一下这个映射的过程。

根据w=∑ni=1αiyixiw=∑i=1nαiyixi,分类函数可写成:

f(x)=(∑i=1nαiyixi)Tx+b=∑i=1nαiyixTix+b=∑i=1nαiyi〈xi,x〉+bf(x)=(∑i=1nαiyixi)Tx+b=∑i=1nαiyixiTx+b=∑i=1nαiyi〈xi,x〉+b

〈⋅〉〈·〉表示向量内积。这个形式的有趣之处在于,对新点x的预测,只需要计算它与训练数据点的内积即可。因为所有非支持向量所对应的系数αα都是0,因此对于新点的内积计算实际上只要针对少量的“支持向量”而不是所有的训练数据。

经过映射,分类函数变成

f(x)=∑i=1nαiyi〈ϕ(xi),ϕ(x)〉+bf(x)=∑i=1nαiyi〈ϕ(xi),ϕ(x)〉+b

而αα可以通过求解如下问题得到:

maxα∑i=1nαi−12∑i,j=1nαiαjyiyj〈ϕ(xi),ϕ(xj)〉s.t.αi≥0,i=1,…,nmaxα∑i=1nαi−12∑i,j=1nαiαjyiyj〈ϕ(xi),ϕ(xj)〉s.t.αi≥0,i=1,…,n

这样,似乎是拿到非线性数据,就找一个适当的映射ϕϕ,把原来的数据映射到新空间中,再做线性SVM即可。不过这个适当的映射可不是好惹的。二维空间做映射,需要5个维度,三维空间做映射,需要19个维度,维度数目是爆炸性增长的。到了无穷维,根本无法计算。这个时候就需要核函数出马了。

观察上式,映射只是一个中间过程,我们实际需要的是计算内积。如果有一种方式可以在特征空间中直接计算内积。就能很好地避免维数灾难了,这样直接计算的方法称为核函数方法。

核是一个函数κκ,对所有x1x1,x2x2,满足

κ(x1,x2)=〈ϕ(x1),ϕ(x2)〉κ(x1,x2)=〈ϕ(x1),ϕ(x2)〉

,这里ϕϕ是从xx到内积特征空间FF的映射。

几个常用的核函数

通常人们会从一些常用的核函数中选择(根据问题和数据的不同,选择不同的参数,实际上就是得到了不同的核函数),例如:

高斯核κ(x1,x2)=exp(−|x1−x2|22σ2)κ(x1,x2)=exp⁡(−|x1−x2|22σ2),这个空间会将原始空间映射到无穷维空间。不过,如果σσ选得很大的话,高次特征上的权重实际上衰减得非常快,所以实际上(数值上近似一下)相当于一个低维的空间;反过来,如果σσ选得很小的话,则可以将任意的数据映射为线性可分。当然,这不一定是好事,因为随之而来的可能是非常严重的过拟合问题。不过,总的来说,通过调控参数,高斯核实际上具有相当的灵活性,也是使用最广泛的核函数之一。下图所示的例子便是把低维空间不可分数据通过高斯核函数映射到了高维空间:多项式核

κ(x1,x2)=(〈x1,x2〉+R)dκ(x1,x2)=(〈x1,x2〉+R)d

,这个核所对应的映射实际上是可以写出来的,该空间的维度是

(m+dd)(m+dd),其中mm是原始空间的维度。线性核κ(x1,x2)=〈x1,x2〉κ(x1,x2)=〈x1,x2〉,这实际上就是原始空间中的内积。这个核存在的主要目的是使得“映射后空间中的问题”和“映射前空间中的问题”两者在形式上统一起来了(意思是说,咱们有的时候,写代码,或写公式的时候,只要写个模板或通用表达式,然后再代入不同的核,便可以了,于此,便在形式上统一了起来,不用再分别写一个线性的,和一个非线性的)。

核函数的本质

总结一下核函数,实际是三点:

实际中,当我们遇到线性不可分的样例,常用做法是把样例特征映射到高维空间中但如果凡是遇到线性不可分的样例,一律映射到高维空间,那么这个维度大小是会高到可怕的此时,核函数就隆重登场了,核函数的价值在于它虽然也是将特征进行从低维到高维的转换,但核函数绝就绝在它事先在低维上进行计算,而将实质上的分类效果表现在了高维上,也就如上文所说的避免了直接在高维空间中的复杂计算。

处理噪音

回顾此前的介绍,SVM用来处理线性可分的问题。后来为了处理非线性数据,使用核函数将原始数据映射到高维空间,转化为线性可分的问题。但是有时候,并不是数据本身是非线性结构的,而只是因为数据有噪音。对于这种偏离正常位置很远的数据点,我们称之为outlier。超平面本身就是只有少数几个支持向量组成,如果支持向量里存在outlier,就会有严重影响。如下图:

用黑圈圈起来的那个蓝点就是一个outlier,它偏离了自己原本所应该的那个半空间,如果直接忽略掉,原本的分隔超平面还是挺好的,但是由于这个outlier的出现,导致分隔超平面不得不被挤歪了,变成黑色虚线所示,同时margin也相应变小了。更严重的是,如果outlier再往右上移动一些距离的话,将无法构造出能将数据分开的超平面来。

为了处理这种情况,SVM允许数据点在一定程度上偏离一下超平面。上图中,黑色实线所对应的距离,就是该outlier偏离的距离,如果把它移动回来,就刚好落在原来的超平面上,而不会使超平面发生变形了。具体来说,原来的约束条件变成:

yi(wTxi+b)≥1−ξi,i=1,…,nyi(wTxi+b)≥1−ξi,i=1,…,n

其中ξi≥0ξi≥0称为松弛变量,对应数据点xixi允许偏离的函数间隔的量。对于一般的数据(非支持向量,也非outlier),这个值就是0。如果ξiξi任意大的话,那任意的超平面都是符合要求的。所以,我们在原来的目标函数后面加上一项,使得这些ξiξi的总和也要最小:

min12||w||2+C∑i=1nξimin12||w||2+C∑i=1nξi

,其中C是一个参数,用于控制目标函数中两项(寻找margin最大的超平面和保证数据点偏差最小)之间的权重。注意,ξiξi是需要优化的变量,而CC是一个事先确定好的常量。完整的目标函数是:

min12||w||2+C∑i=1nξis.t.yi(wTxi+b)≥1−ξi,i=1,…,nmin12||w||2+C∑i=1nξis.t.yi(wTxi+b)≥1−ξi,i=1,…,n

通过拉格朗日对偶求解,

w=∑i=1nαiyixiw=∑i=1nαiyixi

∑i=1nαiy=0∑i=1nαiy=0

求解的问题可以变换为

maxα∑i=1nαi−12∑i,j=1nαiαjyiyjxTixjs.t.0≤αi≤C,i=1,…,nmaxα∑i=1nαi−12∑i,j=1nαiαjyiyjxiTxjs.t.0≤αi≤C,i=1,…,n

对比之前的结果,只不过是αα多了一个上限CC。

小结

SVM是一个最大间距分类器。在线性可分的情况下,它的目标函数是min12|w|2s.t.,yi(wTxi+b)≥1,i=1,…,nmin12|w|2s.t.,yi(wTxi+b)≥1,i=1,…,n,较好的求解方法是转换为拉格朗日对偶问题,并用SMO算法进行求解。在线性不可分的情况下,其基本思想是,将低维线性不可分的问题映射为高维可分的问题。具体实现办法是:利用核函数,在低维空间进行运算,而将实质上的分类效果表现在高维上。考虑到数据点中可能存在噪音的干扰,需要将目标函数中加入松弛变量而求解的思路和方法不变。

二、代码及结果

环境是python3.x+sklearn+pythcarm

# -*- coding: utf-8 -*-# @Time : /8/23 10:38# @Author : Barry# @File : mnist_svm.py# @Software: PyCharm Community Editionimport pickleimport gzip# Third-party librariesimport numpy as npdef load_data():"""返回包含训练数据、验证数据、测试数据的元组的模式识别数据训练数据包含50,000张图片,测试数据和验证数据都只包含10,000张图片"""f = gzip.open('./MNIST_data/mnist.pkl.gz', 'rb')training_data, validation_data, test_data = pickle.load(f,encoding='bytes')f.close()return (training_data, validation_data, test_data)# Third-party librariesfrom sklearn import svmimport timedef svm_baseline():print (time.strftime('%Y-%m-%d %H:%M:%S') )training_data, validation_data, test_data = load_data()# 传递训练模型的参数,这里用默认的参数clf = svm.SVC(C=100.0, kernel='rbf', gamma=0.03)# clf = svm.SVC(C=8.0, kernel='rbf', gamma=0.00,cache_size=8000,probability=False)# 进行模型训练clf.fit(training_data[0], training_data[1])# test# 测试集测试预测结果predictions = [int(a) for a in clf.predict(test_data[0])]num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1]))print ("%s of %s test values correct." % (num_correct, len(test_data[1])))print (time.strftime('%Y-%m-%d %H:%M:%S'))if __name__ == "__main__":svm_baseline()

运行结果:

-08-23 13:33:329848 of 10000 test values correct.-08-23 13:43:20

准确率大约98.48%

三、SVM识别MNIST算法过程

SVM分类算法以另一个角度来考虑问题。其思路是获取大量的手写数字,常称作训练样本,然后开发出一个可以从这些训练样本中进行学习的系统。换言之,SVM使用样本来自动推断出识别手写数字的规则。随着样本数量的增加,算法可以学到更多关于手写数字的知识,这样就能够提升自身的准确性。

本文采用的数据集就是著名的“MNIST数据集”。这个数据集有60000个训练样本数据集和10000个测试用例。直接调用scikit-learn库中的SVM,使用默认的参数,1000张手写数字图片,判断准确的图片就高达9435张。

通常,对于分类问题。我们会将数据集分成三部分,训练集、测试集、交叉验证集。用训练集训练生成模型,用测试集和交叉验证集进行验证模型的准确性。

需要说明的是,svm.SVC()函数的几个重要参数。直接用help命令查看一下文档,这里我稍微翻译了一下:

C : 浮点型,可选 (默认=1.0)。误差项的惩罚参数C

kernel : 字符型, 可选 (默认=’rbf’)。指定核函数类型。只能是’linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’ 或者自定义的。如果没有指定,默认使用’rbf’。如果使用自定义的核函数,需要预先计算核矩阵。

degree : 整形, 可选 (默认=3)。用多项式核函数(‘poly’)时,多项式核函数的参数d,用其他核函数,这个参数可忽略

gamma : 浮点型, 可选 (默认=0.0)。’rbf’, ‘poly’ and ‘sigmoid’核函数的系数。如果gamma是0,实际将使用特征维度的倒数值进行运算。也就是说,如果特征是100个维度,实际的gamma是1/100。

coef0 : 浮点型, 可选 (默认=0.0)。核函数的独立项,’poly’ 和’sigmoid’核时才有意义。

可以适当调整一下SVM分类算法,看看不同参数的结果。当我的参数选择为C=100.0, kernel=’rbf’, gamma=0.03时,预测的准确度就已经高达98.5%了。

相同的C,gamma越大,分类边界离样本越近。相同的gamma,C越大,分类越严格。

下图是不同C和gamma下分类器交叉验证准确率的热力图

由图可知,模型对gamma参数是很敏感的。如果gamma太大,无论C取多大都不能阻止过拟合。当gamma很小,分类边界很像线性的。取中间值时,好的模型的gamma和C大致分布在对角线位置。还应该注意到,当gamma取中间值时,C取值可以是很大的。

在实际项目中,这几个参数按一定的步长,多试几次,一般就能得到比较好的分类效果了。

小结

回顾一下整个问题。我们进行了如下操作。对数据集分成了三部分,训练集、测试集和交叉验证集。用SVM分类模型进行训练,依据测试集和验证集的预测结果来优化参数。依靠sklearn这个强大的机器学习库,我们也能解决手写识别这么高大上的问题了。事实上,我们只用了几行简单代码,就让测试集的预测准确率高达98.5%。

事实上,就算是一般性的机器学习问题,我们也是有一些一般性的思路的,如下:

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。