700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > 机器学习|卷积神经网络(CNN) 手写体识别 (MNIST)入门

机器学习|卷积神经网络(CNN) 手写体识别 (MNIST)入门

时间:2021-04-29 15:47:42

相关推荐

机器学习|卷积神经网络(CNN) 手写体识别 (MNIST)入门

人工智能,机器学习,监督学习,神经网络,无论哪一个都是非常大的话题,都覆盖到可能就成一本书了,所以这篇文档只会包含在 RT-Thread 物联网操作系统,上面加载 MNIST 手写体识别模型相关的部分的知识。因为机器学习并不是纯软件开发,简单地调用库函数 API,需要有一定的理论支撑,如果完全不介绍理论部分,可能就不知道为什么模型要这样设计,模型出了问题应该怎样改善。所以我会分成应用和理论两部分连载 4 篇文章来写,相信大家坚持看下去会有很多收货的,我也尽可能把理论和应用都介绍清楚。

下面来看一下提纲:

神经网络相关理论训练卷积神经网络模型运行卷积神经网络模型总结

引言

这一部分会说明这个文档会包含哪些内容,以及不会包含哪些内容,因为人工智能,机器学习,监督学习,神经网络,无论哪一个都是非常大的话题,都覆盖到可能就成一本书了,所以这篇文档只会包含与 RT-Thread 上面加载 MNIST 手写体识别模型相关的部分。

当然,在每一部分的最后我也会给出参考文献,参考文献是个非常重要的部分,一方面它可以补充我没有介绍到的部分,另一方面也可以提供一些支撑,因为现在网上文档太多了,但是并不是每一篇文档都没有任何错误,比如大家对我写的一些公式结论表示怀疑的话,就可以在参考文献里找到更详细的推导证明。

这篇文档可能还是会非常长,因为机器学习并不是纯软件开发,简单地调用库函数 API,需要有一定的理论支撑,如果完全不介绍理论部分,可能就不知道为什么模型要这样设计,模型出了问题应该怎样改善。不过文档如果写太长大家可能很难有耐心看完,特别是理论部分会有很多公式,但是机器学习确实又对理论基础编程能力都有一些要求,相信坚持看下去还是会有很多收货的,我也尽可能把理论和应用都介绍清楚。

之后一篇文档就基本是纯实际应用了,不会有太多理论内容了,用 Darknet 机器学习框架训练一个目标检测模型。

如果对机器学习理论比较清楚,可以直接看第二部分 Keras 训练模型如果对 Keras 机器学习框架也比较熟悉了,可以直接跳转到第三部分 RT-Thread 加载 onnx 模型如果对 RT-Thread 和 onnx 模型都很熟悉了,那我们可以一起交流下如何在嵌入式设备上高效实现机器学习算法 :blush:

首先,简单介绍一下上面提到的各个话题的范围 (Domain),人工智能 (Artifitial Intelligence) 是最大的话题,如果用一张图来说明的话:

然后机器学习 (Machine Learning) 就是这篇文档的主题了,但是机器学习依旧是一个非常大的话题:

这里简单介绍一下上面提到的三种类型:

监督学习 (Supervised Learning): 这应当是应用最多的领域了,例如人脸识别,我提前先给你大量的图片,然后告诉你当中哪些包含了人脸,哪些不包含,你从我给的照片中总结出人脸的特征,这就是训练过程。最后我再提供一些从来没有见过的图片,如果算法训练得好的话,就能很好的区分一张图片中是否包含人脸。所以监督学习最大的特点就是有训练集,告诉模型什么是对的,什么是错的。

非监督学习 (Unsupervised Learning): 例如网上购物的推荐系统,模型会对我的浏览记录进行分类,然后自动向我推荐相关的商品。非监督学习最大的特点就是没有一个标准答案,比如水杯既可以分类为日用品,也可以分类为礼品,都没有问题。

强化学习 (Reinforcement Learnong): 强化学习应当是机器学习当中最吸引人的一个部分了,例如 Gym 上就有很多训练电脑自己玩游戏最后拿高分的例子。强化学习主要就是通过试错 (Action),找到能让自己收益最大的方法,这也是为什么很多都例子都是电脑玩游戏。

所以文档后面介绍的都是关于监督学习,因为手写体识别需要有一些训练集告诉我这些图像实际上应该是什么数字,不过监督学习的方法也有很多,主要有分类和回归两大类:

分类 (Classification):例如手写体识别,这类问题的特点在于最后的结果是离散的,最后分类的数字只能是 0, 1, 2, 3 而不会是 1.414, 1.732 这样的小数。

回归 (Regression):例如经典的房价预测,这类问题得到的结果是连续的,例如房价是会连续变化的,有无限多种可能,不像手写体识别那样只有 0-9 这 10 种类别。

这样看来,接下来介绍的手写体识别是一个分类问题。但是做分类算法也非常多,这篇文章要介绍的是应用非常多也相对成熟的神经网络 (Neural Network)。

人工神经网络 (Artifitial Neural Network):这是个比较通用的方法,可以应用在各个领域做数据拟合,但是像图像和语音也有各自更适合的算法。

卷积神经网络 (Convolutional Neural Network):主要应用在图像领域,后面也会详细介绍。

循环神经网络 (Recurrent Neural Network):比较适用于像声音这样的序列,因此在语言识别领域应用比较多。

最后总结一下,这篇文档介绍的是人工智能下面发展比较快的机器学习分支,然后解决的是机器学习监督学习下面的分类问题,用的是神经网络里的卷积神经网络方法

神经网络相关理论

这一部分主要介绍神经网络的整个运行流程,怎么准备训练集,什么是训练,为什么要训练,怎么进行训练,以及训练之后得到了什么。

1.1 线性回归 (Linear Regression)

1.1.1 回归模型

要做机器学习训练预测,我们首先得知道自己训练的模型是什么样的,还是以最经典的线性回归模型为例,后面的人工神经网络 (ANN) 其实可以看做多个线性回归组合。那么什么是线性回归模型呢?

比如下面图上这些散点,望能找到一条直线进行拟合,线性回归拟合的模型就是:

$$y = kx + b$$

这样如果以后有一个点 x = 3,不在图上这些点覆盖的区域,我们也可以通过训练好的线性回归模型预测出对应的 y。不过上面的公式通常使用另外一种表示方法,最终的预测值也就是 y 通常用 $h_\theta$ (hypothesis) 表示,而它的下标 $\theta$ 代表不同训练参数也就是 k, b。这样模型就成了:

$$h\theta = \theta0+\theta1x1$$

但是这样表示模型还不够通用,比如 x 可能不是一个一维向量,例如经典的房价预测,我们要知道房价,可能需要房子大小,房间数等很多因素,因此把上面的用更通用的方法表示:

$$h\theta = \theta0x0 + \theta1x1= \begin{bmatrix} \theta0 & \theta1 \ \end{bmatrix} · \begin{bmatrix} x0 \ x_1 \ \end{bmatrix} = \theta^T x$$

这就是线性回归的模型了,只要会向量乘法,上面的公式计算起来还是挺轻松的。

顺便一提,$\theta$ 需要一个转置 $\theta^T$,是因为我们通常都习惯使用列向量。上面这个公式和 $y=kx+b$ 其实是一样的,只是换了一种表示方法而已,不过这种表示方法就更加通用,而且也更加简洁优美了:

$$ h_\theta =\theta^T x $$

1.1.2 评价指标

为了让上面的模型能够很好的拟合这些散点,我们的目标就是改变模型参数 $\theta0$ 和 $\theta1$,也就是这条直线的斜率和截距,让它能很好的反应散点的趋势,下面的动画就很直观的反应了训练过程。

可以看到,一开始是一条几乎水平的直线,但是慢慢地它的斜率和截距就移动到一个比较好的位置,那么问题来了,我们要怎么评价这条直线当前的位置满不满足我们的需求呢?

一个很直接的想法就是求出所有散点实际值 y 和我们模型的测试值 $h_\theta$ 相差的绝对值,这个评价指标我们就称为损失函数 $J(\theta)$ (cost function):

$$ J(\theta) = \frac{1}{2m}\sum{i=1}^m(h\theta(x^i)-y^i )^2 $$

函数右边之所以除以了2是为了求倒数的时候更加方便,因为如果右边的公式求导,上面的平方就会得到一个2,刚好和分母里的2抵消了。

这样我们就有了评价指标了,损失函数计算出来的值越小越好,这样就知道当前的模型是不时能很好地满足需求,下一步就是告诉模型该如何往更好的方向优化了,这就是训练 (Training) 过程。

1.1.3 模型训练

为了让模型的参数 $\theta$ 能够往更好的方向运动,也就是很自然的想法就是向下坡的方向走,比如上面的损失函数其实是个双曲线,我们只要沿着下坡的方向走总能走到函数的最低点:

那么什么是"下坡"的方向呢?其实就是导数的方向,从上面的动画也可以看出来,黑点一直是沿着切线方向逐渐走到最低点的,如果我们对损失函数求导,也就是对 $J(\theta)$ 求导:

$$ J(\theta)^\prime= \frac{1}{m}*x_\theta(h-y)$$

我们现在知道 $\theta$ 应该往哪个方向走了,那每一次应该走多远呢?就像上面的动画那样,黑点就算知道了运动方向,每一次运动多少也是需要确定的。这个每次运动的多少称之为学习速率 $\alpha$ (learning rate),这样我们就知道参数每次应该向哪个方向运动多少了:

$$ \theta = \theta - \alpha \frac{1}{m} * (x^T * (h-y))$$

这种训练方法就是很有名的梯度下降法(Gradient Descent),当然现在也有很多改进的训练方法例如 Adam,其实原理都差不多,这里就不做过多的介绍了。

1.1.4 总结

机器学习的流程总结出来就是,我们先要设计一个模型,然后定义一个评价指标称之为损失函数,这样我们就知道怎么去判断模型的好坏,接下来就是用一种训练方法,让模型参数能朝着能让损失函数减少的方向运动,当损失函数几乎不再减少的时候,我们就可以认为训练结束了。最终训练得到的就是模型的参数,使用训练好的模型我们就可以对其他的数据进行预测了。

1.2 非线性回归 (Logistic Regression)

我们回到手写体识别的例子,上面介绍的线性回归最后得到的是一个连续的数值,但是手写体识别最后的目标是得到一个离散的数值,也就是 0-9,那么这要怎么做到呢?

$$ h_\theta =\theta^T x $$

这个就是上一部分的模型,其实很简单,只需要在最后的结果再加一个 sigmoid 函数,把最终得到的结果限制在 0-1 就可以了。

就像上面图中的公式那样,sigmoid 函数就是:

$$g(z) = \frac{1}{1+e^{-t}}$$

如果把它应用到线性回归的模型,我们就得到了一个非线性回归模型,也就是 Logistic Regression:

$$ h_\theta = g(\theta^T x) $$

这样就可以确保我们最后得到的结果肯定是在 0-1 之间了,然后我们可以定义如果最后的结果大于 0.5 就是 1,小于 0.5 就是 0。

1.3 人工神经网络 (ANN)

现在我们介绍了连续的线性回归模型 Linear Regression,和离散的非线性回归模型 Logistic Regression,模型都非常简单,写在纸上也就不过几厘米的长度。那么这么简单的模型到底是怎么组合成非常好用的神经网络的呢?

其实上面的模型可以看做是只有一层的神经网络,我们输入 $x$ 经过一次计算就得到输出 $h_\theta$ 了:

$$ h_\theta = g(\theta^T x) $$

如果我们不那么快得到计算结果,而是在中间再插入一层呢?就得到了有一层隐藏层的神经网络了。

上面这张图里,我们用 $a$ 代表激活函数 (activation function) 的输出,激活函数也就是上一部分提到的 sigmoid 函数,为了将输出限制在 0-1,如果不这么做,很有可能经过几层神经网络的计算,输出值就爆炸到一个很大很大的数了。当然除了 sigmoid 函数外,激活函数还有很多,例如下一部分在卷积神经网络里非常常用的 Relu。

另外,我们用带括号的上标代表神经网络的层数。例如 $a^{(1)}$ 代表第一层神经网络输出。当然,第一层就是输入层,并不需要经过任何计算,所以可以看到图上的 $a^{(1)}=x$,第一层的激活函数输出直接就是我们的输入 $x$。但是,$\theta^{(1)}$ 不是代表第一层的参数,而是第一层与第二层之间的参数,毕竟参数存在于两层网络之间的计算过程。

于是,我们可以总结一下上面的神经网络结构:

输入层:$a^{(1)} = x$隐藏层:$a^{(2)} = g(\theta^{(1)}a^{(1)})$输出层:$h(\theta) = g(\theta^{(2)}a^{(2)})$

如果我们设置最后的输出层节点是 10 个,那就刚好可以用来表示 0-9 这 10 个数字了。

如果我们再多增加几个隐藏层,是不是看起来就有点像是互相连接的神经元了?

如果我们再深入一点 Go Deeper (论文里作者提到,他做深度学习的灵感其实源自于盗梦空间)

这样我们就得到一个深度神经网络了:

如果你想知道,具体应当选多少层隐藏层,每个隐藏层应该选几个节点,这就跟你从哪里来,要到哪里去一样,是神经网络的终极问题了。

最后,神经网络的训练方法是用的反向传播 (Back Propagation),如果感兴趣可以在这里找到更加详细的介绍。

1.4 卷积神经网络 (CNN)

终于到了后面会用到的卷积神经网络了,从前面的介绍可以看到,其实神经网络的模型非常简单,用到的数学知识也不多,只需要知道矩阵乘法,函数求导

Cov2D Dropout Relu Maxpooling Softmax Flatten

1.5 参考文献

斯坦福经典机器学习入门视频

线性回归

反向传播

本文首发于 GitChat,未经授权,转载需与 GitChat 联系。

阅读全文: /gitchat/activity/5d6638aac1bcac51e63e2901

您还可以下载 CSDN 旗下精品原创内容社区 GitChat App , GitChat 专享技术内容哦。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。