700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 【Python】利用高斯朴素贝叶斯模型实现光学字符识别

【Python】利用高斯朴素贝叶斯模型实现光学字符识别

时间:2022-12-20 07:32:21

相关推荐

【Python】利用高斯朴素贝叶斯模型实现光学字符识别

光学字符识别问题:手写数字识别。简单点说,这个问题包括图像中字符的定位和识别两部分。为了演示方便,我们选择使用 Scikit-Learn 中自带的手写数字数据集。

1.加载并可视化手写数字

首先用 Scikit-Learn 的数据获取接口加载数据,并简单统计一下:

>>>from sklearn.datasets import load_digits>>>digits = load_digits()>>>digits.images.shape(1797, 8, 8)

这份图像数据是一个三维矩阵:共有 1797 个样本,每张图像都是 8 像素 ×8 像素。对前100 张图进行可视化:

>>>import matplotlib.pyplot as plt>>>fig,axes = plt.subplots(10,10, figsize=(8, 8),subplot_kw={'xticks':[], 'yticks':[]},gridspec_kw=dict(hspace=0.1, wspace=0.1))>>>for i, ax in enumerate(axes.flat):ax.imshow(digits.images[i], cmap='binary', interpolation='nearest')ax.text(0.05, 0.05, str(digits.target[i]),transform=ax.transAxes, color='green')

为了在 Scikit-Learn 中使用数据,需要一个维度为 [n_samples, n_features] 的二维特征矩阵——可以将每个样本图像的所有像素都作为特征,也就是将每个数字的 8 像素 ×8 像素平铺成长度为 64 的一维数组。另外,还需要一个目标数组,用来表示每个数字的真实值(标签)。这两份数据已经放在手写数字数据集的 data 与 target 属性中,直接使用即可:

>>>X = digits.data>>>X.shape(1797, 64)>>>y = digits.target>>>y.shape(1797, )

2.无监督学习:降维

虽然我们想对具有 64 维参数空间的样本进行可视化,但是在如此高维度的空间中进行可视化十分困难。因此,我们需要借助无监督学习方法将维度降到二维。这次试试流形学习算法中的 Isomap算法对数据进行降维:

>>>from sklearn.manifold import Isomap>>>iso = Isomap(n_components=2)>>>iso.fit(digits.data)>>>data_projected = iso.transform(digits.data)>>>data_projected.shape(1797, 2)

现在数据已经投影到二维。把数据画出来,看看从结构中能发现什么:

>>>plt.scatter(data_projected[:, 0], data_projected[:, 1], c=digits.target,edgecolor='none', alpha=0.5,cmap=plt.cm.get_cmap('Spectral',10))>>>plt.colorbar(label='digit label', ticks=range(10))>>>plt.clim(-0.5, 9.5)

这幅图呈现出了非常直观的效果,让我们知道数字在 64 维空间中的分离(可识别)程度。虽然有些瑕疵,但从总体上看,各个数字在参数空间中的分离程度还是令人满意的。这其实告诉我们:用一个非常简单的有监督分类算法就可以完成任务。下面来演示一下。

3. 数字分类

我们需要找到一个分类算法,对手写数字进行分类。和前面学习鸢尾花数据一样,先将数据分成训练集和测试集,然后用高斯朴素贝叶斯模型来拟合:

>>>from sklearn.model_selection import train_test_split>>>Xtrain,Xtest,ytrain,ytest = train_test_split(X,y,random_state=0)>>>from sklearn.naive_bayes import GaussianNB>>>model = GaussianNB()>>>model.fit(Xtrain,ytrain)>>>y_model = model.predict(Xtest)

模型预测已经完成,现在用模型在训练集中的正确识别样本量与总训练样本量进行对比,获得模型的准确率:

>>>from sklearn.metrics import accuracy_score>>>accuracy_score(ytest,y_model)0.8333333333333334

可以看出,通过一个非常简单的模型,数字识别率就可以达到 80% 以上!但仅依靠这个指标,我们无法知道模型哪里做得不够好,解决这个问题的办法就是用混淆矩阵(confusion matrix)。可以用 Scikit-Learn 计算混淆矩阵,然后用 Seaborn 画出来:

>>>from sklearn.metrics import confusion_matrix>>>mat = confusion_matrix(ytest,y_model)>>>import seaborn as sns>>>sns.heatmap(mat, square=True, annot=True, cbar=False)>>>plt.xlabel('predicted value')>>>plt.ylabel('true value')

从图中可以看出,误判的主要原因在于许多数字 2 被误判成了数字 1 或数字 8。另一种显示模型特征的直观方式是将样本画出来,然后把预测标签放在左下角,用绿色表示预测正确,用红色表示预测错误:

>>>fig, axes = plt.subplots(10, 10, figsize=(8, 8),subplot_kw={'xticks':[], 'yticks':[]},gridspec_kw=dict(hspace=0.1, wspace=0.1))>>>test_images=Xtest.reshape(-1,8,8)>>>for i, ax in enumerate(axes.flat):ax.imshow(test_images[i], cmap='binary', interpolation='nearest')ax.text(0.05, 0.05, str(y_model[i]),transform=ax.transAxes,color='green' if (ytest[i] == y_model[i]) else 'red')

我们下次再见,如果还有下次的话!!!

欢迎关注微信公众号:516数据工作室

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。