700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > svm手写数字识别_KNN 算法实战篇如何识别手写数字

svm手写数字识别_KNN 算法实战篇如何识别手写数字

时间:2022-11-07 07:51:55

相关推荐

svm手写数字识别_KNN 算法实战篇如何识别手写数字

上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字

1,手写数字数据集

手写数字数据集是一个用于图像处理的数据集,这些数据描绘了[0, 9]的数字,我们可以用KNN 算法来识别这些数字。

MNIST是完整的手写数字数据集,其中包含了60000 个训练样本和10000 个测试样本。

[MNIST](/exdb/mnist/)

sklearn中也有一个自带的手写数字数据集

共包含 1797 个数据样本,每个样本描绘了一个8*8像素的[0, 9]的数字。

每个样本由 65 个数字组成:

前 64 个数字是特征数据,特征数据的范围是[0, 16]

最后一个数字是目标数据,目标数据的范围是[0, 9]

[手写数字数据集]

(/scikit-learn/scikit-learn/blob/master/sklearn/datasets/data/digits.csv.gz)

我们抽出 5 个样本来看下:

0,0,5,13,9,1,0,0,0,0,13,15,10,15,5,0,0,3,15,2,0,11,8,0,0,4,12,0,0,8,8,0,0,5,8,0,0,9,8,0,0,4,11,0,1,12,7,0,0,2,14,5,10,12,0,0,0,0,6,13,10,0,0,0,00,0,0,12,13,5,0,0,0,0,0,11,16,9,0,0,0,0,3,15,16,6,0,0,0,7,15,16,16,2,0,0,0,0,1,16,16,3,0,0,0,0,1,16,16,6,0,0,0,0,1,16,16,6,0,0,0,0,0,11,16,10,0,0,10,0,0,4,15,12,0,0,0,0,3,16,15,14,0,0,0,0,8,13,8,16,0,0,0,0,1,6,15,11,0,0,0,1,8,13,15,1,0,0,0,9,16,16,5,0,0,0,0,3,13,16,16,11,5,0,0,0,0,3,11,16,9,0,20,0,7,15,13,1,0,0,0,8,13,6,15,4,0,0,0,2,1,13,13,0,0,0,0,0,2,15,11,1,0,0,0,0,0,1,12,12,1,0,0,0,0,0,1,10,8,0,0,0,8,4,5,14,9,0,0,0,7,13,13,9,0,0,30,0,0,1,11,0,0,0,0,0,0,7,8,0,0,0,0,0,1,13,6,2,2,0,0,0,7,15,0,9,8,0,0,5,16,10,0,16,6,0,0,4,15,16,13,16,1,0,0,0,0,3,15,10,0,0,0,0,0,2,16,4,0,0,4

使用该数据集,需要先加载:

>>>fromsklearn.datasetsimportload_digits>>> digits = load_digits()

查看第一个图像数据:

>>>digits.images[0]array([[0.,0.,5.,13.,9.,1.,0.,0.],[0.,0.,13.,15.,10.,15.,5.,0.],[0.,3.,15.,2.,0.,11.,8.,0.],[0.,4.,12.,0.,0.,8.,8.,0.],[0.,5.,8.,0.,0.,9.,8.,0.],[0.,4.,11.,0.,1.,12.,7.,0.],[0.,2.,14.,5.,10.,12.,0.,0.], [ 0., 0., 6., 13., 10., 0., 0., 0.]])

我们可以用matplotlib将该图像画出来:

>>>importmatplotlib.pyplotasplt>>>plt.imshow(digits.images[0])>>> plt.show()

[matplotlib](/)

画出来的图像如下,代表0

2,sklearn 对 KNN 算法的实现

sklearn库的neighbors模块实现了KNN相关算法,其中:

KNeighborsClassifier类用于分类问题

KNeighborsRegressor类用于回归问题

这两个类的构造方法基本一致,这里我们主要介绍KNeighborsClassifier类,原型如下:

KNeighborsClassifier( n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=None, **kwargs)

来看下几个重要参数的含义:

n_neighbors:即KNN中的 K 值,一般使用默认值 5。

weights:用于确定邻居的权重,有三种方式:

weights=uniform,表示所有邻居的权重相同。

weights=distance,表示权重是距离的倒数,即与距离成反比。

自定义函数,可以自定义不同距离所对应的权重,一般不需要自己定义函数。

algorithm:用于设置计算邻居的算法,它有四种方式:

algorithm=auto,根据数据的情况自动选择适合的算法。

algorithm=kd_tree,使用KD 树算法。

KD 树是一种多维空间的数据结构,方便对数据进行检索。

KD 树适用于维度较少的情况,一般维数不超过 20,如果维数大于 20 之后,效率会下降。

algorithm=ball_tree,使用球树算法。

KD 树一样都是多维空间的数据结构。

球树更适用于维度较大的情况。

algorithm=brute,称为暴力搜索

它和KD 树相比,采用的是线性扫描,而不是通过构造树结构进行快速检索。

缺点是,当训练集较大的时候,效率很低。

leaf_size:表示构造KD 树球树时的叶子节点数,默认是 30。

调整 leaf_size 会影响树的构造和搜索速度。

[KD树]

(https://scikit-/stable/modules/generated/sklearn.neighbors.KDTree.html)

[球树]

(https://scikit-/stable/modules/generated/sklearn.neighbors.BallTree.html)

3,构造 KNN 分类器

首先加载数据集:

from sklearn.datasets import load_digitsdigits=load_digits()data=digits.data#特征集target = digits.target # 目标集

将数据集拆分为训练集(75%)和测试集(25%):

fromsklearn.model_selectionimporttrain_test_splittrain_x,test_x,train_y,test_y=train_test_split( data, target, test_size=0.25, random_state=33)

构造KNN分类器:

from sklearn.neighbors import KNeighborsClassifier#采用默认参数knn = KNeighborsClassifier()

拟合模型:

knn.fit(train_x, train_y)

预测数据:

predict_y = knn.predict(test_x)

计算模型准确度:

from sklearn.metrics import accuracy_scorescore=accuracy_score(test_y,predict_y)print score # 0.98

最终计算出来模型的准确度是98%,准确度还是不错的。

4,总结

本篇文章使用KNN 算法处理了一个实际的分类问题,主要介绍了以下几点:

介绍了sklearn中自带的手写数字集,并用matplotlib模块画出了数字图像。

介绍了sklearnneighbors.KNeighborsClassifier类的用法。

使用KNeighborsClassifier来识别手写数字。

(本节完。)

点击查看往期内容回顾

KNN 算法-理论篇-如何给电影进行分类

决策树算法-理论篇-如何计算信息纯度

决策树算法-实战篇-鸢尾花及波士顿房价预测

朴素贝叶斯分类-理论篇-如何通过概率解决分类问题

朴素贝叶斯分类-实战篇-如何进行文本分类

欢迎关注作者公众号,获取更多技术干货。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。