
06 - Transfer Learning: Classifying Cats and Dogs with a MobileNet V2 Model Pretrained on ImageNet

Date: 2020-06-13 16:26:44



Contents

1. Data Preprocessing
   1.1 Downloading the Data
2. Creating the Base Model from a Pretrained Convolutional Network
3. Feature Extraction
   3.1 Freezing the Convolutional Base
   3.2 Adding a Classification Head
   3.3 Compiling the Model
4. Training the Model
5. Fine-Tuning
   5.1 Unfreezing the Top Layers
   5.2 Compiling the Model
   5.3 Continuing Training
6. Evaluation and Prediction
7. Summary

Transfer learning applies knowledge learned on a previous task to a new problem, often yielding a good solution with much less data and training.

The example in this article is model-based transfer: we reuse the parameters of an existing model. This approach is especially common with neural networks, because a network's structure can be transferred directly. The well-known practice of fine-tuning is a good example of parameter transfer.

For more material on transfer learning, see:

/epubit17/article/details/110390339
/qq_42951560/article/details/110244616

1. Data Preprocessing

1.1 Downloading the Data

Here we use a dataset containing several thousand images of cats and dogs. Download and extract the zip archive of images, then use the tf.keras.preprocessing.image_dataset_from_directory utility to create a tf.data.Dataset for training and validation.

_URL = '/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
print(os.path.dirname(path_to_zip))  # /public/home/zhaiyuxin/.keras/datasets
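The snippet above only downloads the archive; train_dataset, validation_dataset, and test_dataset are used below without their creation being shown. The following is a sketch of the usual setup from the official TensorFlow transfer-learning tutorial this walkthrough follows; the full download URL and the one-fifth test split are assumptions, not something shown in the original article:

```python
import os
import tensorflow as tf

# Full URL assumed from the official TensorFlow tutorial (the article's copy is truncated).
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

# Build training and validation datasets from the extracted directory tree.
train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(PATH, 'train'),
    shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
validation_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(PATH, 'validation'),
    shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)

# Carve a test set out of one fifth of the validation batches (assumption).
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
```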

# Display the first nine images and labels from the training set
class_names = train_dataset.class_names
plt.figure(figsize=(10, 10))
# dataset.take(1) takes the first element of the dataset (the first one, not a random one).
# Since train_dataset is batched, take(1) yields one whole batch; we display its first nine images.
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
plt.show()
# print(train_dataset.take(1))  # <TakeDataset shapes: ((None, 160, 160, 3), (None,)), types: (tf.float32, tf.int32)>

Configure the dataset for performance: use buffered prefetching to load images from disk without I/O becoming blocking.

AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)  # prefetch overlaps data preparation with model execution
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

Here we use data augmentation to reduce overfitting. When you don't have a large image dataset, it is good practice to artificially introduce sample diversity by applying random but realistic transformations to the training images, such as rotation or horizontal flipping. This exposes the model to different aspects of the training data and reduces overfitting.

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

Note: these layers are active only during training, when you call model.fit. They are inactive when the model is used in inference mode, as in model.evaluate or model.predict.

Let's repeatedly apply the augmentation to the same image to see its effect:

for image, _ in train_dataset.take(1):
    plt.figure(figsize=(10, 10))
    first_image = image[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        # tf.expand_dims adds a batch dimension on axis 0
        augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
        plt.imshow(augmented_image[0] / 255)
        plt.axis('off')
plt.show()

(Figure: nine randomly augmented variants of the same training image.)

Next, we will use tf.keras.applications.MobileNetV2 as the base model. This model expects pixel values in the [-1, 1] range, but at this point the pixel values in our images are in [0, 255]. To rescale them, use the preprocessing method that ships with the model.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
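MobileNetV2's preprocessing rescales pixels from [0, 255] to [-1, 1]; the transformation is equivalent to x / 127.5 - 1. A quick NumPy check of that formula (a sketch, not a call into Keras itself):

```python
import numpy as np

# The MobileNetV2 rescaling maps 0 -> -1, 127.5 -> 0, and 255 -> 1.
pixels = np.array([0.0, 127.5, 255.0])
rescaled = pixels / 127.5 - 1.0
print(rescaled)  # [-1.  0.  1.]
```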

2. Creating the Base Model from a Pretrained Convolutional Network

We will create the base model from the MobileNet V2 model developed at Google. It has been pretrained on ImageNet, a large dataset of 1.4 million images and 1000 classes. ImageNet is a research training dataset with a wide variety of categories, such as jackfruit and syringe. This base of knowledge will help us classify cats and dogs in our specific dataset.

First, you need to pick which layer of MobileNet V2 to use for feature extraction. The very last classification layer (on "top", since most diagrams of machine learning models go from bottom to top) is not very useful. Instead, you will follow the common practice of depending on the very last layer before the flatten operation. This layer is called the "bottleneck layer". Compared with the final/top layer, the bottleneck layer's features retain more generality.

First, instantiate a MobileNet V2 model preloaded with weights trained on ImageNet. By specifying the include_top=False argument, you load a network that doesn't include the top classification layers, which is ideal for feature extraction.

Keras ships pretrained models for prediction, feature extraction, and fine-tuning; available models include Xception, VGG16, ResNet50, MobileNetV2, and others. For usage examples of these models, see: /weixin_39506322/article/details/88640679

IMG_SHAPE = IMG_SIZE + (3,)  # (160, 160, 3)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,  # exclude the top fully connected layers
                                               weights='imagenet')  # load weights pretrained on ImageNet

This feature extractor converts each 160x160x3 image into a 5x5x1280 block of features. Let's see what it does to an example batch of images:

image_batch, label_batch = next(iter(train_dataset))
print(image_batch.shape)  # (32, 160, 160, 3)
feature_batch = base_model(image_batch)
print(feature_batch.shape)  # (32, 5, 5, 1280)
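The 5x5 spatial size is no accident: MobileNet V2 downsamples by an overall stride of 32 (five stride-2 stages), and its final 1x1 convolution (Conv_1) expands the channels to 1280. A quick arithmetic sketch:

```python
# 160x160 input, overall stride 32, 1280 output channels from Conv_1.
input_size = 160
overall_stride = 32        # product of the five stride-2 stages in MobileNet V2
feature_channels = 1280    # output channels of the final 1x1 conv (Conv_1)

spatial = input_size // overall_stride
print((spatial, spatial, feature_channels))  # (5, 5, 1280)
```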

3. Feature Extraction

3.1 Freezing the Convolutional Base

It is important to freeze the convolutional base before compiling and training the model. Freezing (by setting layer.trainable = False) prevents the weights in a given layer from being updated during training. MobileNet V2 has many layers, so setting the entire model's trainable flag to False freezes all of them.

base_model.trainable = False

We can inspect the model's architecture with base_model.summary():

Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape          Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3)] 0
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)    864         input_1[0][0]
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)    128         Conv1[0][0]
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)    0           bn_Conv1[0][0]
__________________________________________________________________________________________________
... (expanded_conv and block_1 through block_16 inverted residual blocks omitted for brevity) ...
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)    409600      block_16_project_BN[0][0]
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)    5120        Conv_1[0][0]
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)    0           Conv_1_bn[0][0]
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

3.2 Adding a Classification Head

To generate predictions from the block of features, average over the 5x5 spatial locations with a tf.keras.layers.GlobalAveragePooling2D layer, converting the features into a single 1280-element vector per image.

# Convert the features to a single 1280-element vector per image
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
# print(feature_batch_average.shape)  # (32, 1280)
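GlobalAveragePooling2D does nothing more than average over the two spatial axes. A NumPy sketch of the same operation on a synthetic feature block:

```python
import numpy as np

# A fake (batch, 5, 5, 1280) feature block, like base_model's output.
feature_block = np.random.rand(32, 5, 5, 1280).astype("float32")

# Global average pooling = mean over the two spatial axes (1 and 2).
pooled = feature_block.mean(axis=(1, 2))
print(pooled.shape)  # (32, 1280)
```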

Apply a tf.keras.layers.Dense layer to convert these features into a single prediction per image. You don't need an activation function here, because this prediction will be treated as a logit, or raw prediction value: positive numbers predict class 1, negative numbers predict class 0.

prediction_layer = tf.keras.layers.Dense(1)  # a Dense layer converts the features into a single prediction per image
prediction_batch = prediction_layer(feature_batch_average)
# print(prediction_batch.shape)  # (32, 1)
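To see why no activation is needed, recall that a logit becomes a probability through the sigmoid function (which BinaryCrossentropy(from_logits=True) applies internally). A small pure-Python sketch:

```python
import math

def sigmoid(logit):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))

# Positive logits map to probabilities above 0.5 (class 1),
# negative logits to probabilities below 0.5 (class 0).
print(round(sigmoid(2.0), 3))   # 0.881
print(round(sigmoid(-2.0), 3))  # 0.119
```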

Build the model by chaining together the data augmentation, rescaling, base_model, and feature extractor layers using the Keras Functional API. As mentioned earlier, use training=False because our model contains BatchNormalization layers. (When you set layer.trainable = False, a BatchNormalization layer runs in inference mode and will not update its mean and variance statistics. **When you unfreeze a model that contains BatchNormalization layers for fine-tuning, you should keep them in inference mode by passing training=False when calling the base model.** Otherwise, the updates applied to the non-trainable weights will destroy what the model has learned.)

# Build the model by chaining data augmentation, rescaling, base_model
# and the feature-extractor layers with the Keras functional API
inputs = tf.keras.Input(shape=(160, 160, 3))  # unify the input size
x = data_augmentation(inputs)    # data augmentation
x = preprocess_input(x)          # input preprocessing
x = base_model(x, training=False)  # training=False because the model contains BatchNormalization layers
x = global_average_layer(x)      # convert the features to one vector per image
x = tf.keras.layers.Dropout(0.2)(x)  # apply Dropout
outputs = prediction_layer(x)    # the predicted output value
model = tf.keras.Model(inputs, outputs)

3. Compile the model

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),  # from_logits=True: the loss applies the sigmoid itself, so the model outputs raw logits
              metrics=['accuracy'])
# model.summary()

The model summary is shown below:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0
_________________________________________________________________
dense (Dense)                (None, 1)                 1281
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

The 2.5 million parameters in MobileNet are frozen, but there are 1,281 trainable parameters in the Dense layer. They are split between two tf.Variable objects: the weights and the bias.

print(len(model.trainable_variables)) # 2
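Those counts follow directly from the head's shape: Dense(1) on a 1280-element vector holds a (1280, 1) kernel and a (1,) bias, i.e. two variables and 1,281 parameters. A quick framework-free sanity check:

```python
# The two tf.Variable objects of a Dense(1) layer on 1280 input features,
# represented here simply as name -> shape.
variables = {"kernel": (1280, 1), "bias": (1,)}

def num_params(shape):
    """Number of scalar parameters in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n

total = sum(num_params(shape) for shape in variables.values())
print(len(variables), total)  # 2 variables, 1281 parameters
```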

4. Train the model

After training for 10 epochs, you should see about 95% accuracy on the validation set. First, evaluate the model before any training:

loss0, acc0 = model.evaluate(validation_dataset)
# print("initial loss: {:.2f}".format(loss0))
# print("initial accuracy: {:.2f}".format(acc0))

The output is:

26/26 [==============================] - 3s 69ms/step - loss: 0.9336 - accuracy: 0.4220
initial loss: 0.93
initial accuracy: 0.42

The initial model's accuracy is 42%. Now train for 10 epochs and observe the effect:

initial_epochs = 10
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)

We can see the accuracy rise to around 95%.

Using plt, let's plot the learning curves of the training and validation accuracy/loss when the MobileNet V2 base model is used as a fixed feature extractor.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

The curves look as follows:

The validation metrics are clearly better than the training metrics. The main reason is that layers like tf.keras.layers.BatchNormalization and tf.keras.layers.Dropout affect accuracy during training; they are turned off when the validation loss is calculated.

To a lesser extent, it is also because training metrics report the average over an epoch, while validation metrics are only evaluated after the epoch, so the validation metrics see a model that has trained slightly longer.
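The Dropout half of this effect is easy to demonstrate with a minimal, framework-free sketch of inverted dropout (a toy illustration, not the Keras implementation): during training, random units are zeroed and the survivors are rescaled, while at inference the layer is an identity, so the network effectively "improves" at evaluation time.

```python
import random

def dropout(activations, rate, training):
    """Inverted dropout: zero units at random while training, identity at inference."""
    if not training:
        return list(activations)
    keep = 1.0 - rate
    # Survivors are scaled by 1/keep so the expected activation is unchanged
    return [a / keep if random.random() < keep else 0.0 for a in activations]

random.seed(0)
acts = [1.0, 2.0, 3.0, 4.0]
print(dropout(acts, rate=0.5, training=True))   # some units zeroed, survivors scaled by 2
print(dropout(acts, rate=0.5, training=False))  # [1.0, 2.0, 3.0, 4.0] (unchanged)
```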

5. Fine-tuning

In the feature-extraction experiment above, we only trained a few layers on top of the MobileNet V2 base model. The weights of the pretrained network were not updated during training.

**One way to increase performance even further is to train (or "fine-tune") the weights of the top layers of the pretrained model alongside the classifier you added.** The training process will force the weights to be tuned from generic feature maps to features specifically associated with the dataset.

Note: you should only attempt this after you have trained the top-level classifier with the pretrained model set to non-trainable. If you add a randomly initialized classifier on top of a pretrained model and attempt to train all the layers jointly, the magnitude of the gradient updates will be too large (due to the classifier's random weights), and the pretrained model will forget what it has already learned.

Additionally, you should try to fine-tune a small number of top layers rather than the whole MobileNet model. **In most convolutional networks, the higher up a layer is, the more specialized it is. The first few layers learn very simple and generic features that generalize to almost all types of images.** As you go higher up, the features become increasingly specific to the dataset the model was trained on. The goal of fine-tuning is to adapt these specialized features to the new dataset rather than overwrite the generic learning.

5.1 Unfreeze the top layers of the model

Unfreeze the base_model and set the bottom layers to non-trainable. Then recompile the model (necessary for these changes to take effect) and resume training.

# Unfreeze the top layers of the model
base_model.trainable = True

# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

The base model has 154 layers.

Number of layers in the base model: 154
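The freezing loop above can be mimicked with plain Python objects (a toy stand-in for Keras layers, not real Keras code): with 154 layers and fine_tune_at = 100, the bottom 100 layers are frozen and the top 54 stay trainable.

```python
class Layer:
    """Minimal stand-in for a Keras layer: just a name and a trainable flag."""
    def __init__(self, name):
        self.name = name
        self.trainable = True

# Same depth as the MobileNetV2 base model
layers = [Layer(f"layer_{i}") for i in range(154)]

fine_tune_at = 100
# Freeze everything below the fine_tune_at boundary
for layer in layers[:fine_tune_at]:
    layer.trainable = False

trainable = sum(1 for layer in layers if layer.trainable)
print(trainable)  # 54 layers at the top remain trainable
```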

5.2 Compile the model

As we are **training a much larger model and want to readapt the pretrained weights, it is important to use a lower learning rate at this stage.** Otherwise, the model could overfit very quickly.

print("-------------------------Fine Tuning-------------------------")
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate / 10),  # use a lower learning rate when retuning pretrained weights in a much larger model
              metrics=['accuracy'])

Use model.summary() to inspect the model structure at this point:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0
_________________________________________________________________
dense (Dense)                (None, 1)                 1281
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________

How many trainable variables does the model have now?

print(len(model.trainable_variables)) # 56

5.3 Continue training the model

If you trained to convergence earlier, this step will improve your accuracy by a few percentage points.

fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs
history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)

After fine-tuning, the model nearly reaches 98% accuracy on the validation set.

Let's take a look at the learning curves of the training and validation accuracy/loss when fine-tuning the last few layers of the MobileNet V2 base model and training the classifier on top of it. The validation loss is much higher than the training loss, so there may be some overfitting.

# Join the learning curves of the two training runs into one plot
acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']
loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs - 1, initial_epochs - 1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs - 1, initial_epochs - 1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

There may also be some overfitting because the new training set is relatively small and similar to the original MobileNet V2 dataset.

6. Evaluation and prediction

Finally, you can verify the performance of the model on new data using the test set.

loss, accuracy = model.evaluate(test_dataset)
# 6/6 [==============================] - 1s 79ms/step - loss: 0.0157 - accuracy: 0.9948

Now we can use this model to predict whether your pet is a cat or a dog.

# Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()  # as_numpy_iterator() walks the dataset batch by batch
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)  # output 0 if the value < 0.5, otherwise 1

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

The result is:

Predictions:
[0 0 0 1 0 0 1 1 1 1 0 1 1 1 1 0 1 1 0 1 1 1 1 1 0 0 0 0 0 0 1 0]
Labels:
[0 0 0 1 0 0 1 1 1 1 0 1 1 1 1 0 1 1 0 1 1 1 1 1 0 0 0 0 0 0 1 0]
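One way to quantify the agreement between the two rows above is plain batch accuracy, matched position by position; a small framework-free check on this batch:

```python
# The predicted classes and true labels for the batch shown above
predictions = [0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
               1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0]
labels      = [0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
               1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0]

# Count positions where prediction and label agree
correct = sum(p == y for p, y in zip(predictions, labels))
print(f"{correct}/{len(labels)} correct, accuracy = {correct / len(labels):.2f}")
# -> 32/32 correct, accuracy = 1.00
```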

We display the images with their predicted labels to verify that the predictions are correct:

plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image_batch[i].astype("uint8"))
    plt.title(class_names[predictions[i]])
    plt.axis("off")
plt.show()


7. Summary

**Using a pretrained model for feature extraction:** when working with a small dataset, it is common practice to take advantage of the features learned by a model trained on a larger dataset in the same domain. ==To do this, you instantiate the pretrained model and add a fully connected classifier on top. The pretrained model is "frozen", and only the weights of the classifier are updated during training.== In this case, the convolutional base extracts all the features associated with each image, and you just train a classifier that determines the image class from that set of extracted features.

**Fine-tuning a pretrained model:** to further improve performance, you may want to repurpose the top-level layers of the pretrained model to the new dataset via fine-tuning. ==In this case, you tune the weights so that the model learns high-level features specific to the dataset.== This technique is usually recommended when the training dataset is large and very similar to the original dataset the pretrained model was trained on.
