A Study of How Fine-tuning Works in Deep Learning Networks, Using Convolutional Neural Networks as an Example - 郑瀚Andrew
2023-04-26 17:51 Author: www.cnblogs.com

A pretrained model is a model that has already been trained on some dataset, usually a large-scale one. Common examples include:

  • network architectures such as VGG16/19
  • network architectures such as ResNet
  • large-scale datasets such as ImageNet
  • large-scale datasets such as COCO

Normally, the networks commonly used in image-recognition tasks, such as VGG16/19, are excellent architectures that others have already tuned, and we do not need to modify their structure.

References:

https://zhuanlan.zhihu.com/p/35890660
https://github.com/szagoruyko/loadcaffe

Let us use a single-neuron network to explain the basic principle of model fine-tuning:

  • Step 1: Suppose our neural network takes the form Y = W * X
  • Step 2: We want to find a W such that the output is Y=1 when the input is X=2, i.e. we want W=0.5: 1 = W * 2
  • Step 3: Following the usual training procedure, W is first initialized from a distribution with mean 0 and variance 1; suppose W is initialized to 0.1: Y = 0.1 * X
  • Step 4: The forward pass (FP) begins: with X=2 and W=0.1, the output is Y=0.2, so the error against the target value 1 is 0.8: 1 <====== 0.2 = 0.1 * 2
  • Step 5: Backpropagation (BP) uses the 0.8 error to update W; suppose this update gives W=0.2, so the output becomes 0.4 and the error against the target is 0.6: 1 <====== 0.4 = 0.2 * 2
  • Step 6: Perhaps after 10 or 20 rounds of backpropagation, W finally reaches the value we want, 0.5: Y = 0.5 * X
  • Step 7: Now imagine that before training starts, someone tells you W should be around 0.47
  • Step 8: Then from the very first step the error is only 0.06, and one or two BP updates may be enough to bring W to 0.5: 1 <====== 0.94 = 0.47 * 2

Step 7 is equivalent to being handed a pre-trained model, and Step 8 is fine-tuning on top of that pre-trained model.
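
A minimal numerical sketch of the toy example above (plain Python; the squared-error loss and the learning rate are illustrative choices, not part of the original walkthrough):

# Toy illustration of "pretraining" as a better starting point for W in Y = W * X.
# Loss L = (target - W * x)^2, so the gradient is dL/dW = -2 * x * (target - W * x).

def train(w_init, x=2.0, target=1.0, lr=0.05, tol=1e-3):
    w, steps = w_init, 0
    while abs(target - w * x) > tol:
        grad = -2.0 * x * (target - w * x)
        w -= lr * grad
        steps += 1
    return w, steps

print(train(w_init=0.1))    # cold start: noticeably more update steps
print(train(w_init=0.47))   # "pretrained" start near 0.5: converges in just a few steps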

As you can see, compared with training from scratch, fine-tuning saves a great deal of compute and training time and improves efficiency; it can even improve accuracy (in very large-scale training a model may get stuck in a sub-optimal local region and fail to escape, whereas pretraining has already covered the hardest stretch of the road, so the downstream model has a much easier path to walk).

Attentive readers may notice that how much a pretrained model helps a downstream fine-tuning task depends on several factors:

  • The overlap between the corpus used to pretrain the model and the downstream fine-tuning task: in essence, the pretrained model's weights encode the corpus that was fed into it. If the pretraining task and the downstream fine-tuning task come from very different domains, the pretrained parameters provide almost no benefit and may even hurt.
  • The capacity of the pretrained model itself: in theory, if the pretrained model is large enough to cover part of the core of the downstream task, fine-tuning can re-weight it, activating some neurons and switching off others, so that the pretrained model "grows" toward the downstream task.
  • Whether the pretraining corpus is large and diverse enough, since this determines whether pretraining has really been completed. If the upstream model has not converged, it will still need a great deal of adjustment once the downstream fine-tuning starts, which greatly slows overall convergence. Conversely, if the pretrained model has essentially converged, the downstream fine-tuning dataset can be small: fine-tuning can reach good results on a small dataset and with little training time.
  • Whether the pretrained model's input-layer setup (vectorization method, tensor dimensions, embedding scheme, encoding scheme, input shape, and so on) matches the corresponding structure of the downstream fine-tuning task, or at least transfers to it reasonably well. In theory, the input-layer structure is a codified form of feature-engineering experience and itself represents the model's abstraction of the target task. For example, for a model built for text generation, "forcibly converting" a pixel image into its input format will not yield good training or prediction results.

The core ideas behind a convolutional neural network are:

  • shallow convolutional layers extract low-level features such as edges and contours
  • deep convolutional layers extract abstract features, such as an entire face shape
  • fully connected layers score and classify based on combinations of these features

A model pretrained on a large dataset already has the ability to extract both shallow low-level features and deep abstract features. Compared with not fine-tuning at all, this approach has the following advantages:

  • it avoids training from scratch, shortening training time and saving compute resources
  • it reduces the risk of non-convergence, under-optimized parameters, low accuracy, poor generalization, and overfitting

Dataset 1: little data, but very high similarity to the pretraining data

In this case, all we do is modify the last few layers, or the final softmax layer, to output our own classes.

Dataset 2: little data, low similarity to the pretraining data

In this case, we can freeze the initial layers of the pretrained model (say the first k layers) and retrain the remaining (n-k) layers. Because the new dataset has low similarity to the pretraining data, retraining the higher layers on the new dataset matters a great deal.
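
A minimal Keras sketch of the two recipes above (the value of k, the input size, and the 10-class head are illustrative assumptions, not something fixed by the text):

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model

base = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Dataset 1 (small, similar): freeze the whole base and train only a new classification head.
# Dataset 2 (small, dissimilar): freeze only the first k layers and let the deeper ones adapt.
k = 10  # illustrative choice; k = len(base.layers) reproduces the "dataset 1" recipe
for layer in base.layers[:k]:
    layer.trainable = False
for layer in base.layers[k:]:
    layer.trainable = True

x = Flatten()(base.output)
outputs = Dense(10, activation='softmax')(x)  # assume a 10-class downstream task
model = Model(inputs=base.input, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])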

Dataset 3: lots of data, low similarity to the pretraining data

In this case, since we have a large dataset, training a neural network directly will work well. However, because our data differs substantially from the data used to train the pretrained model, predictions made with the pretrained model will not be effective, so it is better to train a neural network from scratch (training from scratch) on your own data.

Dataset 4: lots of data, high similarity to the pretraining data

This is the ideal situation, and the one where a pretrained model is most effective. The best way to use it is to keep the model's architecture and its initial weights, and then retrain the model starting from the pretrained weights.

  • A common practice is to truncate the last layer (the softmax layer) of the pretrained network and replace it with a new softmax layer suited to our own problem. For example, a network pretrained on ImageNet comes with a 1000-class softmax layer; if our task has 10 classes, the new softmax layer has 10 outputs instead of 1000. We then continue training from the pretrained weights. Be sure to use cross-validation so that the network generalizes well.
  • Train the network with a smaller learning rate. Since the pretrained weights are expected to be much better than random initialization, we do not want to distort them too much or too quickly. A common practice is to make the initial learning rate 10 times smaller than the one used for training from scratch (see the sketch after this list).
  • If the dataset is very small, we usually train only the last layer; if the dataset is of moderate size, freezing the weights of the first few layers of the pretrained network is also common practice. The first layers capture generic features, such as curves and edges, that remain relevant to the new problem, so we keep those weights fixed and let the network concentrate on learning dataset-specific features in the deeper layers.
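
A minimal sketch of the "10x smaller learning rate" rule of thumb (the 0.01 baseline is an assumed from-scratch rate, purely for illustration):

from tensorflow.keras.optimizers import SGD

scratch_lr = 1e-2                # assumed learning rate for training from scratch
finetune_lr = scratch_lr / 10    # rule of thumb: roughly 10x smaller when fine-tuning

optimizer = SGD(learning_rate=finetune_lr, momentum=0.9)
# model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])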

Common pretrained classification networks include Oxford's VGG models, Google's Inception models, and Microsoft's ResNet models, all of them convolutional neural networks (CNNs) pretrained for classification and detection.

Here we use VGG16, a model pretrained on the ImageNet dataset, which has excellent classification performance and adapts well to other datasets.

0x1: Predicting handwritten digits directly with VGG16

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np

model = VGG16(weights='imagenet')

img_path = '6.webp'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])

[Figure: the input image 6.webp, a handwritten digit]

Output:

Predicted: [('n03532672', 'hook', 0.4591384), ('n02910353', 'buckle', 0.032941677), ('n01930112', 'nematode', 0.032439113)]

As we can see, VGG16's highest-probability prediction is "hook"; clearly the VGG16 training set contained no samples of digit images.

0x2: Feeding a handwritten digit through VGG16 and visualizing each layer's output

from keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np
import cv2
import matplotlib.pyplot as plt


def vis_conv(images, n, name, t):
    """visualize conv output and conv filter.
    Args:
           img: original image.
           n: number of col and row.
           t: vis type.
           name: save name.
    """
    size = 64
    margin = 5

    if t == 'filter':
        results = np.zeros((n * size + 7 * margin, n * size + 7 * margin, 3))
    if t == 'conv':
        results = np.zeros((n * size + 7 * margin, n * size + 7 * margin))

    for i in range(n):
        for j in range(n):
            if t == 'filter':
                filter_img = images[i + (j * n)]
            if t == 'conv':
                filter_img = images[..., i + (j * n)]
            filter_img = cv2.resize(filter_img, (size, size))

            # Put the result in the square `(i, j)` of the results grid
            horizontal_start = i * size + i * margin
            horizontal_end = horizontal_start + size
            vertical_start = j * size + j * margin
            vertical_end = vertical_start + size
            if t == 'filter':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end, :] = filter_img
            if t == 'conv':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end] = filter_img

    # Display the results grid
    plt.imshow(results)
    plt.savefig('images/{}_{}.jpg'.format(t, name), dpi=600)
    plt.show()


def conv_output(model, layer_name, img):
    """Get the output of conv layer.

    Args:
           model: keras model.
           layer_name: name of layer in the model.
           img: processed input image.

    Returns:
           intermediate_output: feature map.
    """
    # this is the placeholder for the input images
    input_img = model.input

    try:
        # this is the placeholder for the conv output
        out_conv = model.get_layer(layer_name).output
    except Exception:
        raise Exception('No layer named {}!'.format(layer_name))

    # get the intermediate layer model
    intermediate_layer_model = Model(inputs=input_img, outputs=out_conv)

    # get the output of intermediate layer model
    intermediate_output = intermediate_layer_model.predict(img)

    return intermediate_output[0]


if __name__ == '__main__':
    model = VGG16(weights='imagenet')

    img_path = '6.webp'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    preds = model.predict(x)
    # decode the results into a list of tuples (class, description, probability)
    # (one such list for each sample in the batch)
    print('Predicted:', decode_predictions(preds, top=3)[0])

    conv_output_block1_conv1 = conv_output(model, "block1_conv1", x)
    print("block1_conv1: ", conv_output_block1_conv1)
    vis_conv(conv_output_block1_conv1, 8, "block1_conv1", 'conv')

    conv_output_block1_conv2 = conv_output(model, "block1_conv2", x)
    print("block1_conv2: ", conv_output_block1_conv2)
    vis_conv(conv_output_block1_conv2, 8, "block1_conv2", 'conv')

    conv_output_block2_conv1 = conv_output(model, "block2_conv1", x)
    print("block2_conv1: ", conv_output_block2_conv1)
    vis_conv(conv_output_block2_conv1, 8, "block2_conv1", 'conv')

    conv_output_block2_conv2 = conv_output(model, "block2_conv2", x)
    print("block2_conv2: ", conv_output_block2_conv2)
    vis_conv(conv_output_block2_conv2, 8, "block2_conv2", 'conv')

    conv_output_block3_conv1 = conv_output(model, "block3_conv1", x)
    print("block3_conv1: ", conv_output_block3_conv1)
    vis_conv(conv_output_block3_conv1, 8, "block3_conv1", 'conv')

    conv_output_block3_conv2 = conv_output(model, "block3_conv2", x)
    print("block3_conv2: ", conv_output_block3_conv2)
    vis_conv(conv_output_block3_conv2, 8, "block3_conv2", 'conv')

    conv_output_block5_conv3 = conv_output(model, "block5_conv3", x)
    print("block5_conv3: ", conv_output_block5_conv3)
    vis_conv(conv_output_block5_conv3, 8, "block5_conv3", 'conv')

    print("fc1: ", conv_output(model, "fc1", x))
    print("fc2: ", conv_output(model, "fc2", x))
    print("predictions: ", conv_output(model, "predictions", x))

Output:

1/1 [==============================] - 2s 2s/step
Predicted: [('n03532672', 'hook', 0.4591384), ('n02910353', 'buckle', 0.032941677), ('n01930112', 'nematode', 0.032439113)]
1/1 [==============================] - 0s 53ms/step
block1_conv1:  [[[  0.         42.11969     0.        ...   0.         32.04823
     0.       ]
  [  0.         46.303555   82.50592   ...   0.        324.38284
   164.56157  ]
  [  0.         46.303555   82.50592   ...   0.        324.38284
   164.56157  ]
  ...
  [  0.         46.303555   82.50592   ...   0.        324.38284
   164.56157  ]
  [  0.         46.303555   82.50592   ...   0.        324.38284
   164.56157  ]
  [  2.61003    32.20762   173.75212   ...   0.        517.4678
   391.77734  ]]

 [[  0.         56.784718    0.        ...   0.          0.
     0.       ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  ...
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [377.4901     38.781555  204.19121   ...   0.        382.94656
   378.29724  ]]

 [[  0.         56.784718    0.        ...   0.          0.
     0.       ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  ...
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [377.4901     38.781555  204.19121   ...   0.        382.94656
   378.29724  ]]

 ...

 [[  0.         56.784718    0.        ...   0.          0.
     0.       ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  ...
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [377.4901     38.781555  204.19121   ...   0.        382.94656
   378.29724  ]]

 [[  0.         56.784718    0.        ...   0.          0.
     0.       ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  ...
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [  2.401019   58.89123    63.275116  ...   0.          2.2158926
    10.105784 ]
  [377.4901     38.781555  204.19121   ...   0.        382.94656
   378.29724  ]]

 [[  0.         39.011864    0.        ...   0.          0.
     0.       ]
  [323.5314     39.00029    97.09346   ...   0.          0.
    67.22728  ]
  [323.5314     39.00029    97.09346   ...   0.          0.
    67.22728  ]
  ...
  [323.5314     39.00029    97.09346   ...   0.          0.
    67.22728  ]
  [323.5314     39.00029    97.09346   ...   0.          0.
    67.22728  ]
  [523.7337     25.070164  184.00014   ...   0.        144.83621
   315.41928  ]]]
1/1 [==============================] - 0s 84ms/step
block1_conv2:  [[[982.48444    65.59724     0.        ...  81.02978   698.99084
   172.65338  ]
  [256.9937    101.16306     8.7225065 ... 203.38603   340.56735
     0.       ]
  [314.77548   126.94779     0.        ... 159.34764   175.0137
     0.       ]
  ...
  [314.77548   126.94779     0.        ... 159.34764   175.0137
     0.       ]
  [ 63.2487      0.          0.        ... 125.09357   413.46884
    33.402287 ]
  [  0.          0.          0.        ...  32.059208    0.
     7.143284 ]]

 [[401.39062    97.3492      0.        ... 134.1313    454.73416
     0.       ]
  [  0.         97.926704  136.89134   ... 259.61768   632.9747
     0.       ]
  [  0.        125.44156    95.91204   ... 174.20306   390.24847
     0.       ]
  ...
  [  0.        125.44156    95.91204   ... 174.20306   390.24847
     0.       ]
  [  0.          0.        109.98622   ... 103.348114  854.354
     0.       ]
  [  0.          0.          0.        ...   0.        394.38068
     0.       ]]

 [[396.95483   167.3767      0.        ...  69.25613   207.11255
     4.1853294]
  [  0.        174.81584    76.58766   ... 161.11617   339.40433
     0.       ]
  [151.61284    87.23442    16.130083  ...   6.742235    1.1302795
     0.       ]
  ...
  [151.61284    87.23442    16.130083  ...   6.742235    1.1302795
     0.       ]
  [  0.          0.         70.19446   ...   0.        479.9812
   254.07501  ]
  [  0.          0.          0.        ...   0.        199.8518
    50.87436  ]]

 ...

 [[396.95483   167.3767      0.        ...  69.25613   207.11255
     4.1853294]
  [  0.        174.81584    76.58766   ... 161.11617   339.40433
     0.       ]
  [151.61284    87.23442    16.130083  ...   6.742235    1.1302795
     0.       ]
  ...
  [151.61284    87.23442    16.130083  ...   6.742235    1.1302795
     0.       ]
  [  0.          0.         70.19446   ...   0.        479.9812
   254.07501  ]
  [  0.          0.          0.        ...   0.        199.8518
    50.87436  ]]

 [[196.74297     0.          0.        ...  76.20704   371.12302
   239.03537  ]
  [  0.          0.         54.11582   ... 132.80391   642.51025
   472.34528  ]
  [  0.          0.          4.422485  ...   7.28855   283.40457
   706.94666  ]
  ...
  [  0.          0.          4.422485  ...   7.28855   283.40457
   706.94666  ]
  [  0.          0.         54.947617  ...   0.        688.73157
   731.2318   ]
  [  0.          0.          0.        ...   0.        364.4021
   284.65625  ]]

 [[  0.          0.          0.        ...   0.          0.
     0.       ]
  [  0.          0.          0.        ...   0.        407.869
     0.       ]
  [  0.          0.          0.        ...   0.        198.98882
   101.46747  ]
  ...
  [  0.          0.          0.        ...   0.        198.98882
   101.46747  ]
  [  0.          0.          0.        ...   0.        534.15466
    69.81046  ]
  [287.62454     0.          0.        ...   0.        764.0485
     0.       ]]]
1/1 [==============================] - 0s 76ms/step
block2_conv1:  [[[   0.          0.        146.08685  ... 1138.9917      0.
   1914.1439  ]
  [   0.          0.        617.18994  ...  630.32166     0.
      0.      ]
  [   0.          0.        479.59012  ...  803.52374     0.
    281.59882 ]
  ...
  [   0.          0.        479.59012  ...  803.52374     0.
    281.59882 ]
  [   0.          0.        583.4128   ...  895.7679      0.
    715.7333  ]
  [   0.          0.       1087.817    ... 2163.6226      0.
      0.      ]]

 [[   0.        657.53296     0.       ...  660.99      461.2479
   1719.0864  ]
  [   0.        823.556     349.60562  ...    0.        542.6992
      0.      ]
  [   0.        748.83795   131.92645  ...   30.981398  517.1108
     82.481895]
  ...
  [   0.        748.83795   131.92645  ...   30.981398  517.1108
     82.481895]
  [   0.        826.5497    252.64777  ...   64.045074  392.9257
    619.41876 ]
  [   0.        693.9135   1073.2073   ... 1989.0895    697.90814
      0.      ]]

 [[   0.        239.73143     0.       ...  901.56885   274.7921
   1343.2406  ]
  [   0.        214.44774   181.45721  ...    0.        279.94656
      0.      ]
  [   0.        130.28665     0.       ...   90.52182   205.50911
    130.00967 ]
  ...
  [   0.        130.28665     0.       ...   90.52182   205.50911
    130.00967 ]
  [   0.        230.28584    60.274647 ...   54.528107   35.845345
    758.34717 ]
  [   0.        283.4764    837.31805  ... 1669.6423    417.16782
    390.9171  ]]

 ...

 [[   0.        239.73143     0.       ...  901.56885   274.7921
   1343.2406  ]
  [   0.        214.44774   181.45721  ...    0.        279.94656
      0.      ]
  [   0.        130.28665     0.       ...   90.52182   205.50911
    130.00967 ]
  ...
  [   0.        130.28665     0.       ...   90.52182   205.50911
    130.00967 ]
  [   0.        230.28584    60.274647 ...   54.528107   35.845345
    758.34717 ]
  [   0.        283.4764    837.31805  ... 1669.6423    417.16782
    390.9171  ]]

 [[   0.        149.2003      0.       ...  467.1346    130.91127
   1713.3496  ]
  [   0.         89.11      283.70944  ...    0.        236.00652
      0.      ]
  [   0.         21.128517   52.216312 ...    0.        233.49413
     93.75622 ]
  ...
  [   0.         21.128517   52.216312 ...    0.        233.49413
     93.75622 ]
  [   0.        120.84711   171.13362  ...    0.         73.68687
    632.3945  ]
  [   0.        207.82211   976.44196  ... 1907.8083    525.08185
     29.64562 ]]

 [[   0.        296.92758   171.61426  ...  975.3303    292.51434
   1616.5455  ]
  [   0.        235.07794   710.6981   ...  276.39038     0.
      0.      ]
  [   0.        116.03024   512.0845   ...  650.45764    53.27237
    331.76382 ]
  ...
  [   0.        116.03024   512.0845   ...  650.45764    53.27237
    331.76382 ]
  [   0.        247.85234   603.1937   ...  753.06476    57.02111
    653.146   ]
  [   0.        435.59036  1229.345    ... 2149.0642    365.4059
      0.      ]]]
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7ff36c074790> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 96ms/step
block2_conv2:  [[[  19.134865   65.2908    388.85107  ...   77.567345    0.
      0.      ]
  [ 385.78787     0.         83.92136  ...  823.738       0.
      0.      ]
  [ 362.76718     0.          0.       ...  770.1545      0.
      0.      ]
  ...
  [ 370.19595     0.          0.       ...  693.7316      0.
      0.      ]
  [ 395.07098  1163.4445      0.       ...  685.89105     0.
      0.      ]
  [ 393.64594   221.8914      0.       ...  779.5206      0.
      0.      ]]

 [[   0.          0.        658.96985  ...  266.29254  1334.6693
      0.      ]
  [ 175.15945     0.          0.       ...  927.1358    410.14014
      0.      ]
  [ 113.65867     0.          0.       ...  705.73663   115.82475
    341.95673 ]
  ...
  [  89.81759   278.56213     0.       ...  651.8543    775.20416
    502.7654  ]
  [ 136.82233  1937.8406      0.       ...  647.9445    302.8629
    525.4279  ]
  [ 262.19644   357.42938     0.       ...  750.1874      0.
    489.33453 ]]

 [[   0.          0.        418.21606  ...   12.688118  795.45483
      0.      ]
  [ 234.67218     0.          0.       ...  426.10312     0.
      0.      ]
  [ 145.08507     0.          0.       ...  287.3707      0.
    296.64294 ]
  ...
  [ 103.087685  305.11697    62.120567 ...  267.3017    545.9968
    524.84625 ]
  [ 235.22937  2067.736     239.66722  ...  172.1788    407.2032
    489.35236 ]
  [ 323.7679    407.43408   319.0578   ...  341.47412     0.
    345.82104 ]]

 ...

 [[   0.          0.        580.24994  ...   68.54731   589.51636
      0.      ]
  [ 201.64163     0.        157.14062  ...  501.0832      0.
      0.      ]
  [ 133.07848     0.          0.       ...  351.53003     0.
    415.4161  ]
  ...
  [  86.24023   465.5442     22.741163 ...  337.74213   215.66536
    622.05804 ]
  [ 174.42499  2174.4937     46.142918 ...  286.23798   212.43034
    572.5916  ]
  [ 282.7715    504.28677   132.34572  ...  501.6414      0.
    371.98062 ]]

 [[   0.          0.        247.89134  ...  337.7562    870.8283
      0.      ]
  [ 129.28552     0.          0.       ...  976.0519      0.
      0.      ]
  [   0.        107.290855    0.       ...  696.99493     0.
    248.08282 ]
  ...
  [   0.        545.71716     0.       ...  687.88995   175.53624
    456.3958  ]
  [  40.394768 2056.4695      0.       ...  716.48956   157.10045
    438.98425 ]
  [ 169.84534   324.61182   357.57187  ...  724.79034     0.
    279.55737 ]]

 [[   0.          0.          0.       ...  108.35586  1594.9191
      0.      ]
  [   0.          0.          0.       ...  641.5959    631.3734
      0.      ]
  [   0.          0.          0.       ...  476.7445    236.77658
      0.      ]
  ...
  [   0.          0.          0.       ...  514.81213   659.1744
      0.      ]
  [   0.        558.51337     0.       ...  529.3481    646.179
      0.      ]
  [   0.          0.        318.3686   ...  567.25116     0.
     85.41164 ]]]
WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7ff36c0250d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 86ms/step
block3_conv1:  [[[ 104.03467     0.       7676.7437   ...  284.6595    104.21471
    495.0637  ]
  [   0.        313.47745  5637.235    ...  773.39124   312.671
    710.7272  ]
  [   0.        626.0542   4799.9775   ...  797.72      329.52908
    588.2553  ]
  ...
  [   0.        646.8998   4819.6846   ...  770.5025    316.77924
    555.9198  ]
  [   0.        247.11465  5635.4976   ...  528.78986   281.3929
    570.0344  ]
  [  30.971907    0.       7807.489    ...  149.22829   247.03853
    569.7642  ]]

 [[   0.        871.3891   5385.873    ...  138.57967   142.74121
    983.03674 ]
  [   0.       1012.0134    499.58597  ...  162.09428   256.54013
   1158.3336  ]
  [   0.       1021.0573     28.230726 ...  184.61717   219.79193
    785.92285 ]
  ...
  [   0.       1050.2477      0.       ...  146.40399   266.05975
    744.28723 ]
  [   0.        998.98596   374.99014  ...   64.251434  274.85852
    940.72205 ]
  [   0.        715.7788   5695.999    ...  181.9697    113.495964
    998.7362  ]]

 [[   0.        715.5003   4931.604    ...  174.02486   218.0967
    733.6579  ]
  [   0.        782.053      98.879425 ...  201.88213   215.22943
    785.8126  ]
  [   0.        754.37915     0.       ...  156.65364    84.32829
    489.68857 ]
  ...
  [   0.        784.71075     0.       ...  119.22443   137.86731
    454.97656 ]
  [   0.        741.68567     0.       ...   42.16644   243.78513
    592.8224  ]
  [   0.        564.874    5148.9604   ...  128.61302   147.20853
    733.89886 ]]

 ...

 [[   0.        496.68298  4885.435    ...  318.65524   245.03665
    575.7172  ]
  [   0.        486.5161     27.83389  ...  512.1368    232.01933
    566.13635 ]
  [   0.        477.45157     0.       ...  499.04877    68.24934
    263.7914  ]
  ...
  [   0.        499.77722     0.       ...  459.50702   132.83049
    226.18076 ]
  [   0.        439.5999      0.       ...  320.2604    207.68942
    371.76605 ]
  [   0.        339.14404  5100.4336   ...  253.7242    112.67809
    590.3231  ]]

 [[   0.        347.25443  5573.7017   ...  627.8705    275.148
    631.8805  ]
  [   0.        358.82916   292.3079   ...  979.4485    303.31757
    662.5002  ]
  [   0.        478.66336     0.       ... 1011.04913   144.6257
    358.16284 ]
  ...
  [   0.        500.74857     0.       ...  972.3128    223.55475
    336.5134  ]
  [   0.        355.48328   104.18472  ...  832.22375   270.79025
    496.9038  ]
  [   0.        219.11375  5712.4497   ...  539.98773    84.06546
    667.78613 ]]

 [[   0.        604.2773   7762.388    ...  492.06854   294.44586
    373.23422 ]
  [   0.        660.0235   5493.3257   ...  210.03978   176.89102
    304.05936 ]
  [   0.        675.077    4603.5874   ...  169.29701   125.09003
     53.69849 ]
  ...
  [   0.        701.2141   4594.911    ...  142.22992   227.38722
     59.698753]
  [   0.        718.4968   5527.42     ...  161.2458    129.69702
    249.47922 ]
  [   0.        586.12274  8277.507    ...  435.10352     0.
    348.29013 ]]]
1/1 [==============================] - 0s 105ms/step
block3_conv2:  [[[   0.        971.66376   794.8841   ...  172.1506     10.597431
    794.8708  ]
  [   0.        291.7925    826.4213   ...   39.319454    0.
    718.3281  ]
  [   0.        156.54356   802.0568   ...    0.          0.
    503.39447 ]
  ...
  [   0.        401.88135  1241.3585   ...    0.          0.
    362.15497 ]
  [   0.        675.3719   1448.097    ...    0.          9.820769
    410.58932 ]
  [   0.         10.890532  953.4981   ...  233.22906     0.
    579.7396  ]]

 [[ 575.767    1863.7603    592.8948   ...  245.05453     0.
   1068.8091  ]
  [ 514.8801    844.0041    222.19751  ...    0.          0.
    788.1397  ]
  [  19.14704   444.27817   111.57798  ...    0.          0.
    409.57492 ]
  ...
  [ 252.99167   848.908     513.1679   ...    0.          0.
    312.90305 ]
  [ 591.92786  1448.2924    630.19824  ...    0.          0.
    504.8597  ]
  [   0.        379.8196    763.054    ...    0.         72.78092
    733.65424 ]]

 [[ 287.43423  1910.7128    349.80966  ...  387.3527      0.
   1265.1278  ]
  [   0.        740.3088    124.85873  ...    0.          0.
    918.3699  ]
  [   0.        286.83832   118.424774 ...    0.        177.10791
    486.00412 ]
  ...
  [   0.        735.8566    558.8175   ...    0.        193.26689
    449.90454 ]
  [  53.59411  1525.2466    651.7935   ...    0.        103.276146
    716.995   ]
  [   0.        603.8922    836.88104  ...   50.30762   191.5637
    884.57367 ]]

 ...

 [[ 292.4923   1834.398     444.55945  ...  540.1754     14.972595
   1457.0437  ]
  [   0.        642.0181    319.91138  ...   44.719204  156.22743
   1106.5459  ]
  [   0.        170.11359   338.21768  ...    0.        376.42972
    603.82666 ]
  ...
  [   0.        581.471     737.77203  ...    0.        400.47705
    579.5313  ]
  [   0.       1367.3385    798.4122   ...    0.        260.49323
    826.0262  ]
  [   0.        544.79816   826.0728   ...   77.14375   283.54224
    990.2182  ]]

 [[ 584.30676  1950.905     596.8577   ...  740.97327    81.50432
   1820.6097  ]
  [ 116.4952    835.781     588.2435   ...  225.01852   196.70117
   1720.8013  ]
  [   0.        356.6451    615.6922   ...   77.022446  354.97284
   1198.3191  ]
  ...
  [   0.        733.50824  1012.90985  ...    0.        296.32776
   1099.1088  ]
  [ 186.0945   1339.9901   1179.6779   ...    0.        191.2773
   1315.7777  ]
  [   0.        384.91098  1044.8905   ...  228.41646   209.99303
   1404.8423  ]]

 [[ 608.40894  1603.2566    899.59283  ...  999.1029     64.82636
   1448.8973  ]
  [ 733.8801   1092.808     762.7826   ...  444.4963    137.76027
   1666.6692  ]
  [ 293.81265   823.8305    784.1011   ...  267.9691    135.08733
   1363.0045  ]
  ...
  [ 359.85425  1058.6151   1013.9297   ...  163.37076   159.4037
   1266.5629  ]
  [ 682.56195  1274.0765   1125.9093   ...  177.28194   135.8132
   1424.4539  ]
  [ 148.46483   454.86966   954.3874   ...  199.56137   320.5976
   1351.9453  ]]]
1/1 [==============================] - 0s 146ms/step
block5_conv3:  [[[0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  ...
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]]

 [[0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  ...
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]]

 [[0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  ...
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]]

 ...

 [[0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        1.2440066 0.       ]
  ...
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]]

 [[0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  ...
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]]

 [[0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.5987776 0.       ]
  ...
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]
  [0.        0.        0.        ... 0.        0.        0.       ]]]
1/1 [==============================] - 0s 138ms/step
fc1:  [2.9445024  0.         0.         ... 3.411133   0.         0.92348397]
1/1 [==============================] - 0s 143ms/step
fc2:  [0.         0.         0.         ... 0.01487154 0.         0.        ]
1/1 [==============================] - 0s 153ms/step
predictions:  [8.86679481e-06 3.92886477e-06 1.90436788e-06 1.05222316e-05
 2.95820100e-05 1.05888921e-05 8.18475996e-07 2.97174847e-05
 1.72565160e-05 2.33364801e-04 4.08327742e-06 7.54527573e-05
 2.13582698e-05 5.62608966e-06 2.66997467e-05 6.33711488e-06
 5.91164498e-05 2.96048438e-05 1.54325113e-04 1.61606149e-04
 4.36949313e-06 2.27579949e-04 3.22062464e-04 3.87774286e-04
 1.82932072e-05 1.78626142e-04 1.06591207e-04 1.24504077e-04
 7.32575209e-05 1.67771868e-05 1.42938734e-06 5.67790994e-06
 4.40411623e-06 3.51705899e-06 4.54849214e-06 8.11457880e-07
 1.23381051e-06 9.66434072e-07 1.74248067e-04 4.33074547e-06
 3.25646602e-06 1.54293630e-05 1.26219347e-05 1.96861256e-05
 5.53511854e-05 5.12993356e-05 5.80043570e-06 9.02399624e-05
 7.22741834e-06 1.27374151e-05 2.59617846e-05 3.38299797e-05
 2.39712815e-03 1.74616615e-03 4.85557830e-04 3.29024158e-04
 3.13818571e-04 2.12321938e-05 5.26380085e-04 2.88680475e-03
 1.84028875e-03 8.09462945e-05 6.80478770e-05 2.14936007e-02
 3.19925428e-04 1.66888710e-03 2.04587798e-03 3.49455629e-04
 1.27097068e-03 2.72739620e-04 2.91247284e-06 4.48031962e-04
 1.57972545e-06 7.40459245e-06 1.13488693e-06 9.91967318e-06
 1.07233291e-05 1.95666644e-06 1.85278812e-04 3.24987195e-04
 5.19032947e-05 2.71407112e-06 1.49551488e-05 4.88938567e-05
 5.17146982e-05 1.10298810e-04 1.06869438e-05 2.46440832e-05
 4.66025049e-05 3.63443614e-05 1.12969128e-05 2.55341893e-05
 6.81859092e-05 1.14550072e-04 5.32956028e-06 7.00814735e-06
 1.15207583e-03 1.13513424e-05 1.45131880e-05 9.87223932e-04
 1.41433545e-03 1.42524106e-04 1.87630758e-05 1.45034219e-05
 7.02911166e-06 4.97291239e-07 1.44432697e-05 2.14918982e-05
 2.49657751e-04 1.75241279e-04 7.07811036e-04 3.24391127e-02
 2.39375167e-05 4.87557236e-06 1.68786224e-04 1.72599885e-05
 3.57204262e-04 3.65287306e-05 9.62294143e-06 6.66444794e-06
 6.23000451e-05 4.13392627e-05 1.25281904e-05 2.46514765e-06
 1.16787705e-05 4.36370010e-06 1.15267576e-05 1.03567043e-04
 1.90633407e-04 5.81696622e-06 5.24300151e-04 4.73948821e-05
 6.00396706e-05 2.62725348e-06 9.41229882e-06 4.48861829e-05
 3.18245611e-06 1.09500215e-05 3.04656010e-06 2.81243956e-06
 2.78029938e-06 2.36011829e-06 1.19211859e-06 1.40344800e-05
 6.92425092e-05 2.19969384e-04 2.38212277e-04 5.18192837e-06
 6.66403794e-05 1.19804699e-05 1.12324997e-05 7.82153511e-05
 1.48655672e-05 7.67800702e-06 3.79271878e-05 9.57871907e-06
 1.36488925e-05 9.52548271e-06 1.79516901e-05 3.11920776e-05
 1.07268534e-04 2.04860935e-05 1.13185033e-05 4.38859715e-05
 6.85368195e-06 3.27570451e-05 4.21883669e-06 9.13747954e-06
 1.15643152e-05 7.98587553e-06 9.62191461e-06 2.23533661e-05
 1.18041371e-05 4.75581110e-05 6.63245373e-06 3.38082427e-05
 9.82034999e-06 2.01295570e-05 8.89091098e-06 2.23101542e-05
 2.10599119e-05 1.95221619e-05 2.93983067e-05 1.35727038e-04
 4.10272987e-05 9.92941568e-05 8.58638596e-05 4.45206533e-05
 7.41288459e-05 4.27207560e-05 7.12208493e-05 1.87339421e-04
 5.40639439e-06 6.58450299e-05 7.53286349e-06 1.91383544e-04
 1.07185342e-05 3.62894643e-05 1.38193327e-05 3.58770776e-05
 9.85885981e-06 8.50519336e-06 1.47193816e-04 6.64993204e-05
 7.43968712e-06 2.07755093e-05 3.51842573e-05 4.39709947e-06
 3.83616753e-05 2.99786516e-05 1.62636991e-06 5.47050422e-06
 9.95857590e-07 8.05376112e-06 1.96713572e-05 1.18257765e-06
 1.11786721e-05 3.49282709e-05 4.67216933e-06 1.05762056e-05
 5.35382169e-05 1.22479163e-04 8.24888684e-06 2.67953932e-04
 3.17708400e-05 1.71653865e-05 1.05027771e-04 2.14162956e-05
 4.88646037e-06 4.61531381e-05 7.45789384e-06 2.91185825e-05
 5.80204323e-05 8.73349563e-05 3.94712624e-05 4.85797500e-05
 6.84601901e-06 1.49850293e-05 4.85138225e-05 7.45706493e-05
 1.98496113e-04 3.00224547e-05 8.45372233e-06 6.48311516e-06
 6.54547603e-06 3.71917267e-05 2.83854206e-06 1.78560749e-05
 3.07140799e-05 2.26468183e-05 5.00164570e-05 4.60664432e-06
 5.20592039e-05 3.10437244e-05 7.79263937e-05 5.62111791e-06
 1.49219180e-04 6.47040315e-06 5.18403431e-06 2.83422069e-05
 1.08114955e-05 1.53456713e-05 2.45812495e-04 1.05807967e-05
 4.69596816e-05 1.61335429e-05 1.00145635e-05 5.69761096e-06
 1.74532361e-05 1.20673076e-05 9.43993200e-06 5.01738941e-06
 3.85100338e-06 1.40547309e-05 7.89373280e-06 4.30665978e-06
 9.39401434e-06 8.81400138e-06 2.69927250e-06 2.62271114e-05
 8.21756657e-06 1.31640641e-04 1.97637601e-05 6.78912620e-05
 1.72004147e-04 2.91035598e-04 3.54252334e-05 2.54558254e-05
 1.38019350e-05 1.91044728e-06 7.22885125e-06 1.33819249e-05
 7.12421388e-06 7.87766548e-05 1.78281352e-05 3.34753531e-05
 6.08029450e-06 2.98858026e-06 1.37939816e-04 3.45666740e-05
 6.52200970e-06 3.16649130e-05 3.49477432e-06 1.01652977e-05
 8.41250403e-06 7.48465573e-06 1.35816648e-04 7.22609548e-06
 4.39557334e-06 1.19831084e-05 3.40422557e-05 1.52454516e-06
 1.69746852e-06 1.34438051e-06 1.76554167e-05 2.88769229e-06
 4.23087977e-06 1.05430786e-06 5.98303768e-06 4.44874831e-06
 7.20610979e-06 9.38479934e-06 6.35911192e-07 5.10396058e-06
 6.53882182e-07 1.40259897e-06 4.55490772e-06 7.53509375e-05
 9.45165266e-06 4.56607668e-04 3.46355228e-06 3.41798623e-05
 3.84768509e-06 1.31142251e-05 6.59345415e-06 1.28755937e-05
 9.35764911e-06 7.91293678e-06 2.35082607e-05 2.26178645e-06
 3.31025512e-05 2.76681226e-06 1.68231236e-05 6.80708763e-06
 1.29108651e-06 6.85924388e-05 3.70900016e-05 1.71985685e-05
 1.25700643e-03 1.33214565e-03 3.10425255e-02 5.59107903e-05
 3.51523668e-05 4.40640397e-05 1.89676175e-05 5.17027183e-05
 9.10625458e-05 1.45803888e-05 1.62041426e-04 5.17400040e-05
 4.09077838e-05 4.03765211e-04 1.52316759e-04 4.66719284e-05
 2.34392573e-04 1.60122636e-05 4.58906061e-06 6.39632344e-05
 9.06162240e-05 7.67958554e-05 1.55225789e-05 2.62458780e-05
 4.54723631e-05 2.71644458e-05 1.16712208e-05 6.18937993e-05
 4.40446502e-06 1.69388259e-05 4.64936107e-04 1.75527806e-04
 2.73151163e-05 6.96121060e-05 3.32106974e-05 8.41600195e-06
 2.08298861e-05 1.21705219e-04 6.25848115e-05 5.26691438e-05
 3.41659279e-06 1.30620274e-05 7.36525923e-04 4.74398075e-06
 3.45263470e-05 1.00253281e-04 9.23935477e-06 2.03607378e-05
 1.13465694e-05 2.19904769e-06 5.09470337e-05 4.19838540e-03
 1.00290286e-03 2.63983256e-05 2.80405875e-05 2.27962232e-07
 9.34973650e-06 2.28096338e-04 4.37624931e-06 4.99454563e-06
 3.53755640e-05 9.63599712e-04 4.64696450e-06 4.22794583e-05
 2.49279110e-04 1.11948924e-04 4.00889257e-04 2.80806180e-05
 2.20467977e-04 7.32972927e-04 4.86506411e-04 2.13944048e-04
 2.51623533e-05 1.58264316e-04 1.89990387e-04 5.65126655e-04
 1.82046060e-05 1.41215526e-06 5.97492181e-05 2.10429396e-04
 1.14815513e-04 2.95700811e-05 2.83271838e-05 5.36805019e-04
 3.18742881e-04 5.33307139e-05 3.37226847e-05 1.48667343e-04
 7.55067822e-06 1.52780412e-04 2.95972204e-05 1.19778932e-04
 3.52832176e-05 4.95642707e-05 2.11865432e-03 4.00052872e-03
 2.43429913e-05 1.71246738e-05 4.72480082e-04 1.61542965e-04
 1.42520032e-04 3.93152914e-06 2.28453027e-05 5.02332638e-04
 5.61465931e-05 4.19722019e-05 1.03473103e-05 9.32566982e-05
 2.48103228e-04 3.92103073e-04 2.74504127e-05 1.31670722e-05
 8.29012133e-05 2.35334755e-05 4.90546154e-05 6.12018048e-04
 3.29416767e-02 7.38703005e-04 1.45032809e-05 2.26052930e-06
 5.55469996e-05 1.41960825e-03 1.75519352e-04 1.39583615e-04
 5.32880076e-05 1.64087061e-02 9.01359745e-05 3.83946863e-05
 1.97320719e-06 3.78321100e-04 6.72588721e-05 3.71041562e-04
 4.72625870e-05 1.61895136e-04 2.04839933e-04 3.22288433e-05
 3.52817528e-06 7.15582646e-05 4.79896989e-05 3.53601732e-04
 4.54594474e-03 3.57284152e-05 3.91601556e-04 4.97426256e-04
 1.83074051e-04 1.46165185e-05 1.81997917e-03 8.16113879e-06
 3.32378513e-05 1.41442579e-05 6.49202193e-05 1.11072080e-03
 3.96973446e-05 3.17696031e-05 3.51422088e-04 1.33094509e-04
 1.45075168e-03 5.18648769e-04 3.23256850e-02 2.24043634e-02
 8.97353857e-06 1.05607351e-05 1.93923479e-05 1.62865545e-05
 2.40424965e-02 4.11161134e-04 1.48271674e-05 2.35818443e-05
 9.94408154e-04 1.43786694e-03 1.77713620e-04 6.38488700e-06
 2.69750108e-05 2.89386335e-05 7.80405207e-06 6.41705119e-04
 9.40548416e-05 1.12757407e-05 2.28892022e-05 5.97430590e-05
 8.32233782e-05 9.89095061e-05 2.82501249e-04 3.17303883e-03
 3.17591184e-05 2.72919406e-05 3.76993694e-06 7.63166972e-05
 2.03596119e-05 2.04267471e-05 7.24468118e-05 1.95511733e-03
 1.77471829e-05 9.32528783e-05 4.18644668e-05 2.39925605e-04
 7.61114425e-05 1.34542322e-04 1.36341987e-04 2.72285729e-06
 1.63320874e-06 3.19918210e-04 3.58488120e-04 3.70486436e-04
 2.89479376e-05 1.04429608e-04 9.23851803e-06 4.99161706e-06
 4.57598726e-05 4.37971874e-04 1.42190562e-04 7.56013542e-05
 3.04936093e-05 8.39943314e-05 1.95028661e-05 3.22055821e-05
 8.87363876e-06 4.10715120e-06 1.06259424e-04 1.45254788e-04
 3.37117890e-05 5.98966608e-06 7.07039202e-04 2.28137978e-05
 2.17670658e-05 6.64460094e-05 2.68764183e-04 3.21332118e-05
 2.31042814e-05 2.60967878e-04 1.32772921e-05 4.10596476e-06
 5.84332611e-06 6.55371468e-06 1.23988102e-05 6.55802956e-04
 7.23824138e-03 9.32764597e-05 4.86513818e-05 9.33450181e-04
 3.47442023e-04 2.15923501e-04 6.65367479e-05 2.31268095e-05
 1.44284004e-05 1.65621004e-05 8.02202194e-05 4.12447916e-05
 1.73158958e-04 1.83570213e-04 6.93245465e-06 5.82744105e-05
 4.59138393e-01 2.13877211e-05 5.71083569e-04 8.65393588e-07
 1.98091984e-05 5.98172264e-05 1.43234164e-03 2.18751738e-04
 5.04269119e-05 7.31506816e-06 9.18616934e-05 6.62255989e-05
 4.11376823e-06 8.62064189e-05 1.26205377e-05 4.61055140e-04
 1.08992020e-02 2.66485131e-05 2.64866627e-04 6.62679813e-05
 3.60291087e-05 2.66545121e-05 1.77872658e-04 2.01556808e-03
 2.79729593e-05 1.23751124e-05 9.58843448e-04 4.60301017e-05
 3.57670524e-06 1.80370233e-04 4.55380687e-05 7.21158649e-05
 1.30548215e-04 1.50785688e-03 2.28181725e-05 3.10816872e-03
 3.36440653e-03 1.46413358e-05 1.88198217e-04 1.48446697e-05
 5.23523013e-05 2.66233925e-04 7.40830819e-06 1.24792755e-03
 2.70143355e-04 1.57155337e-05 1.91499304e-04 1.47366300e-04
 3.55853881e-05 1.17728428e-04 4.92268661e-03 4.48991996e-05
 1.09024140e-05 3.84566956e-05 6.10373390e-05 1.22622978e-05
 3.02930621e-05 8.43525595e-06 5.89327174e-05 1.77384354e-05
 9.33787192e-07 1.97890895e-05 3.96184361e-04 3.92160400e-05
 2.23727948e-05 1.97188201e-05 1.45821277e-05 8.40431021e-04
 9.53494819e-05 1.51549818e-06 4.29444408e-05 1.63255812e-04
 1.67631064e-04 4.67124803e-04 1.54450056e-04 7.67227630e-06
 1.39268965e-03 1.28351869e-02 5.00910636e-03 2.14553881e-03
 1.01173273e-03 5.63595968e-04 4.91360843e-05 7.19250529e-05
 1.75622830e-04 9.74295926e-06 3.01298278e-04 5.54160670e-06
 1.24025473e-05 1.86115030e-05 5.24135203e-06 4.18825774e-04
 8.15189014e-06 2.72685011e-05 3.91247977e-06 9.30270925e-03
 1.53549627e-04 1.02977538e-05 1.25478473e-05 6.06908216e-05
 1.17585540e-03 1.44420778e-02 5.44897193e-05 1.53933608e-04
 3.76078897e-05 5.28023884e-06 3.16303522e-05 7.72568455e-06
 2.10181301e-04 1.01022335e-04 1.40602220e-04 1.49609783e-04
 1.14452605e-05 2.89457548e-05 3.71720322e-04 2.38283264e-05
 1.56697679e-05 1.03104067e-04 1.83217016e-05 2.14195767e-04
 2.84243783e-04 5.28251330e-05 2.34640265e-05 7.24710208e-06
 1.14483064e-05 3.84075614e-03 7.89254773e-05 3.62368992e-05
 1.83144530e-05 1.27833104e-04 4.90006569e-05 2.73585611e-04
 3.29049872e-05 3.17845872e-04 4.15099430e-06 3.84936793e-05
 2.37875720e-05 2.39650180e-05 5.25766482e-05 3.92098336e-05
 1.74029192e-04 9.73390524e-06 4.05609608e-05 2.88089959e-05
 1.40124266e-05 7.07016588e-05 2.67811352e-04 1.82182499e-04
 1.04057754e-03 2.22881761e-04 2.29549150e-05 1.89316197e-05
 8.92643220e-06 4.58891445e-05 3.33298551e-04 1.64505072e-05
 1.24487444e-04 5.65690698e-06 1.05001331e-04 9.23672560e-05
 1.13163114e-05 3.78826895e-04 2.27822075e-05 1.01369282e-04
 6.74335679e-05 3.10279633e-04 9.25418772e-06 1.27698237e-04
 3.26955749e-04 1.14762376e-03 2.19624781e-04 6.66490960e-05
 2.14133486e-02 3.51987910e-05 8.94156867e-04 2.21527134e-05
 4.55056172e-04 3.78276734e-03 2.22083996e-04 1.75435252e-06
 6.83424514e-06 5.30135003e-06 9.82568872e-06 7.31593987e-04
 5.02913608e-04 1.05920284e-04 1.37225297e-05 4.70397354e-05
 3.52310999e-05 1.94082645e-06 1.88696606e-03 9.90767949e-05
 2.96163809e-04 5.25712385e-05 2.51091842e-04 3.75009346e-04
 1.63949630e-03 4.90943727e-04 1.16265301e-05 5.67344978e-06
 2.31106173e-06 3.50021315e-03 4.65841076e-05 1.78817288e-06
 8.85085377e-04 9.18242076e-05 7.04456979e-05 3.29272407e-05
 2.82556066e-05 4.24005484e-05 7.68456357e-06 3.95997573e-04
 1.53113469e-05 8.22980510e-05 1.13864508e-05 5.75939293e-06
 1.57799313e-05 4.28730937e-06 1.95369706e-03 2.18281384e-05
 2.45123647e-06 2.31460072e-05 3.70571047e-06 4.15719463e-04
 1.83777098e-04 8.83020984e-05 6.35228772e-03 2.97277846e-04
 4.78334114e-04 2.54291444e-06 7.86322809e-04 1.68983519e-04
 4.10227221e-05 1.39408348e-05 1.27657084e-04 1.54425681e-04
 6.44958694e-04 4.67360951e-04 3.12464399e-04 1.91629195e-04
 1.59293544e-04 9.02580359e-05 4.43688259e-05 4.75798175e-03
 3.29969975e-04 5.27197990e-05 1.94561470e-03 5.67512780e-06
 5.76647471e-05 2.69354228e-03 6.31902512e-05 2.22443996e-05
 1.21067016e-04 1.64704434e-05 1.82369724e-04 6.47963898e-04
 2.28299050e-05 3.77393553e-05 5.06583950e-04 4.50800035e-05
 1.00449382e-04 3.34154814e-03 3.99357203e-04 2.91576784e-04
 1.50415999e-05 3.74069023e-05 5.49084434e-05 4.58612112e-06
 1.48610940e-04 4.78453649e-06 5.50092373e-04 7.97398843e-06
 1.57916365e-04 2.02938754e-04 5.12932615e-07 7.93720974e-05
 1.12120520e-04 5.25517389e-05 3.85814092e-05 2.26931297e-03
 7.04336446e-04 6.22067500e-06 3.24391597e-03 3.80431273e-04
 3.58487770e-04 1.87326194e-04 9.27425208e-06 2.41902526e-05
 1.49540312e-04 1.11114350e-05 1.33219073e-04 3.20076477e-04
 4.06427571e-04 1.01031001e-04 1.21471225e-04 1.39722342e-05
 1.44775596e-03 9.68599925e-05 3.95861082e-03 4.69980342e-03
 9.22366689e-06 1.76984744e-04 1.27497464e-02 7.51280750e-05
 2.43124009e-06 2.66393145e-05 1.12654387e-04 6.42826344e-05
 1.49543357e-05 1.02014004e-04 1.28567963e-05 4.39918629e-04
 1.48356130e-05 2.16726930e-05 5.48292974e-06 6.84269753e-06
 3.99081706e-04 7.80194241e-05 4.16754920e-05 2.64994364e-04
 5.74018704e-05 1.37182778e-05 1.14159811e-05 1.43100833e-05
 8.88659633e-06 1.06520629e-05 8.85260033e-06 5.77346009e-06
 8.25636380e-06 2.16832796e-05 8.95236644e-06 2.41253983e-05
 1.22884922e-02 4.22514995e-06 1.92215273e-04 1.14893555e-05
 9.26747089e-06 3.80918646e-05 2.01568528e-05 1.82601270e-05
 1.66271129e-05 1.34015800e-05 2.14585361e-05 4.06647559e-05
 4.95586664e-06 9.35641292e-05 1.71769386e-06 2.04159005e-05
 1.20855322e-04 5.78344843e-05 1.86209247e-04 9.47380249e-05
 4.63605829e-06 2.66953939e-05 3.25924228e-03 1.41969331e-05
 5.21058428e-05 4.36145774e-06 2.74305257e-05 5.58478814e-06
 1.86247416e-04 3.99841110e-06 1.27266567e-05 6.94881010e-06
 1.22096226e-05 2.83595796e-06 1.18330085e-04 1.95743614e-05
 4.30598017e-03 4.46661434e-05 3.86449283e-05 5.68860676e-04
 4.67448954e-05 4.63131000e-05 9.05528850e-07 9.58958262e-05
 2.17774905e-05 6.33123418e-05 1.37754439e-04 1.22215988e-05
 1.86073139e-05 3.51532503e-06 6.26382916e-06 2.33038532e-04
 2.20870024e-05 2.65913677e-05 2.15374494e-05 7.98078545e-05
 3.53792457e-05 1.82980682e-06 8.29796772e-05 2.25770145e-05
 5.53840528e-06 1.28692284e-06 2.41941045e-04 3.47754917e-06
 3.37785932e-05 7.95326923e-06 3.53732721e-05 8.06681346e-04]


[Figures: feature-map visualizations saved to images/ by vis_conv: conv_block1_conv1, conv_block1_conv2, conv_block2_conv1, conv_block2_conv2, conv_block3_conv1, conv_block3_conv2, conv_block5_conv3]

0x3: Using VGG16 as the pretrained model, fine-tuning a handwritten-digit-recognition network, then comparing visualizations of the base model and the fine-tuned model

From the previous sections we know that the pretrained VGG16 model has little ability to recognize handwritten digits; in other words, there is a large gap between the pretrained model's sample domain and the downstream task's sample domain.

Given that, the following experiments help us understand more concretely what fine-tuning actually adjusts inside the model.

1. Designing a CNN directly for handwritten digit recognition

# In -> [[Conv2D->relu]*2 -> MaxPool2D -> Dropout]*2 -> Flatten -> Dense -> Dropout -> Out
from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Dropout, Flatten, Dense

model = Sequential()

model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='Same',
                 activation='relu', input_shape=(28, 28, 1)))
model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='Same',
                 activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='Same',
                 activation='relu'))
model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='Same',
                 activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(256, activation="relu"))
model.add(Dropout(0.5))
model.add(Dense(10, activation="softmax"))  # the "Out" layer: 10 digit classes

The above is a fairly standard, mainstream CNN architecture; trained on a relatively small training set with ordinary hardware, it already reaches good accuracy.
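
A minimal sketch of how this baseline could be compiled and trained on MNIST, reusing the model defined above (the optimizer, batch size, and epoch count are illustrative assumptions):

from keras.datasets import mnist
from keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0
y_train, y_test = to_categorical(y_train), to_categorical(y_test)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=128, epochs=5, validation_split=0.1)
print(model.evaluate(x_test, y_test))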

2. Freezing the pretrained base model and training only the downstream fine-tune layers: VGG16 without its top layer + fully connected layers, as transfer learning

This model design was chosen for the following reasons:

  • Reuse as much of the pretrained base model as can transfer to the downstream task: in theory, VGG16 has already learned convolution kernels that move from local to global receptive fields, and this ability can also be reused for handwritten digit recognition (a precondition for transfer learning), as the kernel-visualization experiments in the previous section also suggest.
  • VGG16's top layers were designed to recognize 1000 common everyday object categories and do not fit a 10-class digit-recognition task, so the top layer is discarded and a new downstream network dedicated to 10-class handwritten digit recognition is attached.
  • Because the downstream task and the pretrained base model come from different domains, the downstream fine-tuning places some demands on the training set. We therefore augment the data, randomly rotating, translating, scaling, and adding noise to the original dataset, so the model focuses on extracting digit features rather than memorizing the dataset itself (see the augmentation sketch after this list).
  • VGG16's input layer expects 224*224*3 RGB color images, whereas MNIST handwritten digits are 28*28*1 grayscale images. Despite this difference, both are pixel images, so there is some theoretical basis for transfer learning.
  • The training process actually trains only the downstream fully connected network; the parameters of the convolutional base are frozen (not trainable).
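
A minimal sketch of that augmentation step using Keras's ImageDataGenerator (the specific ranges are illustrative assumptions, and the training script below does not itself apply augmentation):

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=10,       # random rotations (degrees)
    width_shift_range=0.1,   # random horizontal translations
    height_shift_range=0.1,  # random vertical translations
    zoom_range=0.1           # random zoom in/out
)
# model.fit(datagen.flow(X_train_data, Y_train_data, batch_size=32),
#           epochs=20, validation_data=(X_train_val, Y_train_val))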

The code below uses VGG16 from tensorflow.keras.applications.vgg16 to fetch the existing VGG16 model and its weights, and then freezes all of VGG16's parameters during training.

On top of that, it adds a few relu fully connected layers and a softmax fully connected layer for multi-class classification, essentially following the design of the generic CNN for handwritten digit recognition above.

The structure of the new stacked model is as follows:

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 48, 48, 3)]       0         
                                                                 
 block1_conv1 (Conv2D)       (None, 48, 48, 64)        1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 48, 48, 64)        36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 24, 24, 64)        0         
                                                                 
 block2_conv1 (Conv2D)       (None, 24, 24, 128)       73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 24, 24, 128)       147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 12, 12, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 12, 12, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 12, 12, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 12, 12, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 6, 6, 256)         0         
                                                                 
 block4_conv1 (Conv2D)       (None, 6, 6, 512)         1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 6, 6, 512)         2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 6, 6, 512)         2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 3, 3, 512)         0         
                                                                 
 block5_conv1 (Conv2D)       (None, 3, 3, 512)         2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 3, 3, 512)         2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 3, 3, 512)         2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 1, 1, 512)         0         
                                                                 
 flatten (Flatten)           (None, 512)               0         
                                                                 
 dense (Dense)               (None, 4096)              2101248   
                                                                 
 dropout (Dropout)           (None, 4096)              0         
                                                                 
 dense_1 (Dense)             (None, 4096)              16781312  
                                                                 
 dropout_1 (Dropout)         (None, 4096)              0         
                                                                 
 dense_2 (Dense)             (None, 10)                40970     
                                                                 
=================================================================
Total params: 33,638,218
Trainable params: 33,638,218
Non-trainable params: 0

Training code (note that model.summary() in the script is called before the base layers are frozen, which is why the summary above reports all parameters as trainable):

from keras.models import Model, load_model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.backend import reshape, expand_dims, spatial_2d_padding, spatial_3d_padding
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Flatten, Dropout, Conv1D
from keras.optimizers import SGD
from keras.datasets import mnist
from keras.utils import to_categorical
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def vis_conv(images, n, name, t):
    """visualize conv output and conv filter.
    Args:
           img: original image.
           n: number of col and row.
           t: vis type.
           name: save name.
    """
    size = 64
    margin = 5

    if t == 'filter':
        results = np.zeros((n * size + 7 * margin, n * size + 7 * margin, 3))
    if t == 'conv':
        results = np.zeros((n * size + 7 * margin, n * size + 7 * margin))

    for i in range(n):
        for j in range(n):
            if t == 'filter':
                filter_img = images[i + (j * n)]
            if t == 'conv':
                filter_img = images[..., i + (j * n)]
            filter_img = cv2.resize(filter_img, (size, size))

            # Put the result in the square `(i, j)` of the results grid
            horizontal_start = i * size + i * margin
            horizontal_end = horizontal_start + size
            vertical_start = j * size + j * margin
            vertical_end = vertical_start + size
            if t == 'filter':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end, :] = filter_img
            if t == 'conv':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end] = filter_img

    # Display the results grid
    plt.imshow(results)
    plt.savefig('images/{}_{}.jpg'.format(t, name), dpi=600)
    plt.show()


def conv_output(model, layer_name, img):
    """Get the output of conv layer.

    Args:
           model: keras model.
           layer_name: name of layer in the model.
           img: processed input image.

    Returns:
           intermediate_output: feature map.
    """
    # this is the placeholder for the input images
    input_img = model.input

    try:
        # this is the placeholder for the conv output
        out_conv = model.get_layer(layer_name).output
    except Exception:
        raise Exception('No layer named {}!'.format(layer_name))

    # get the intermediate layer model
    intermediate_layer_model = Model(inputs=input_img, outputs=out_conv)

    # get the output of intermediate layer model
    intermediate_output = intermediate_layer_model.predict(img)

    return intermediate_output[0]


def get_mnist_data():
    (X_train_data, Y_train_data), (X_test_data, Y_test_data) = mnist.load_data()
    # convert Y label into one-hot
    Y_train_data = to_categorical(Y_train_data)
    Y_test_data = to_categorical(Y_test_data)
    X_train_data = X_train_data.astype('float32') / 255.0
    X_test_data = X_test_data.astype('float32') / 255.0

    # pad the 28*28*1 MNIST images up to 48*48*3 (2+18 pixels on height/width, one zero channel on each side of the gray channel)
    X_train_data = expand_dims(X_train_data, axis=-1)
    X_test_data = expand_dims(X_test_data, axis=-1)
    X_train_data = tf.pad(X_train_data, [[0, 0], [2, 18], [2, 18], [1, 1]])
    X_test_data = tf.pad(X_test_data, [[0, 0], [2, 18], [2, 18], [1, 1]])

    # split the last 2000 training samples off as a validation set
    X_train_val = X_train_data[-2000:, ...]
    X_train_data = X_train_data[:-2000, ...]
    Y_train_val = Y_train_data[-2000:]
    Y_train_data = Y_train_data[:-2000]

    print("np.shape(X_train_data): ", np.shape(X_train_data))
    print("np.shape(X_test_data): ", np.shape(X_test_data))
    print("np.shape(X_train_val): ", np.shape(X_train_val))
    print("np.shape(Y_train_data): ", np.shape(Y_train_data))
    print("np.shape(Y_train_val): ", np.shape(Y_train_val))
    print("Y_train_data[0]: ", Y_train_data[0])

    return (X_train_data, Y_train_data), (X_test_data, Y_test_data), (X_train_val, Y_train_val)


def train_fine_tune():
    # create the base pre-trained model
    base_model = VGG16(weights='imagenet', include_top=False, input_shape=(48, 48, 3))

    # create a fine-tune model
    x = base_model.output
    print("base_model.input_shape: ", base_model.input_shape)
    print("base_model.input_shape[1:]: ", base_model.input_shape[1:])
    print("base_model.output_shape: ", base_model.output_shape)
    print("base_model.output_shape[1:]: ", base_model.output_shape[1:])
    # let's add a fully-connected layer
    x = Flatten()(x)
    x = Dense(4096, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(4096, activation='relu')(x)
    x = Dropout(0.5)(x)
    # and a logistic layer -- let's say we have 10 classes
    predictions = Dense(10, activation='softmax')(x)

    # this is the new model(vgg16+fine-tune model) we will train
    model = Model(inputs=base_model.input, outputs=predictions)
    print("model.input_shape: ", model.input_shape)
    print("model.input_shape[1:]: ", model.input_shape[1:])
    print("model.output_shape: ", model.output_shape)
    print("model.output_shape[1:]: ", model.output_shape[1:])
    model.summary()

    # i.e. freeze all convolutional VGG16 layers
    for layer in base_model.layers:
        layer.trainable = False

    # compile the model (should be done *after* setting layers to non-trainable)
    sgd = SGD(lr=1e-5, decay=1e-6, momentum=0.5, nesterov=True)  # optimizer; note that the fine-tune learning rate is typically ~10x smaller than the rate used to train the base model from scratch
    # model.compile(loss='categorical_crossentropy', optimizer="rmsprop")
    model.compile(loss='categorical_crossentropy', optimizer=sgd)

    # load the mnist data, and fine-tune the new model
    (X_train_data, Y_train_data), (X_test_data, Y_test_data), (X_train_val, Y_train_val) = get_mnist_data()

    # train the model on the new data for a few epochs
    history = model.fit(
        X_train_data, Y_train_data,
        batch_size=32,
        epochs=20,
        validation_data=(X_train_val, Y_train_val)
    )

    model.save('vgg16_plus_dnn_for_mnist.h5')


def predict():
    model = load_model("./vgg16_plus_dnn_for_mnist.h5")

    (X_train_data, Y_train_data), (X_test_data, Y_test_data), (X_train_val, Y_train_val) = get_mnist_data()

    # view the first test image
    plt.imshow(X_test_data[0])
    plt.show()

    print("labels of the first 10 test images: \n", np.argmax(Y_test_data[:10], axis=1))
    print("predictions for the first 10 test images:\n", np.argmax(model.predict(X_test_data[:10]), axis=1))


if __name__ == '__main__':
    train_fine_tune()
    # predict()

Load the fine-tuned model and run it on the test set:

labels of the first 10 test images: 
 [7 2 1 0 4 1 4 9 5 9]

predictions for the first 10 test images:
 [1 2 1 0 4 1 4 9 3 3]

The misclassified images are at indices 0, 8, and 9; per the output above, the model misread them as 1, 3, and 3 respectively (figures omitted).

As we can see, when fine-tuning with the pre-trained base model's parameters frozen, if factors such as the base model's input encoding, tensor dimensions, and tensor sizes do not exactly match the downstream task, the final transfer learning result falls short of its best.
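
Before fully unfreezing the base model (as the next section does), a common middle ground is to unfreeze only the last convolution block and keep the generic low-level filters frozen. A minimal sketch, assuming Keras's standard VGG16 layer naming (block1_conv1 ... block5_conv3):

from tensorflow.keras.applications.vgg16 import VGG16

# Sketch: keep block1-block4 frozen, let only block5 adapt to the new domain.
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(48, 48, 3))
for layer in base_model.layers:
    # a layer is trainable only if its name starts with "block5"
    layer.trainable = layer.name.startswith('block5')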

3、Allow the pre-trained base model's parameters to be adjusted while also training the downstream fine-tune layers: VGG16 without top layer + fully-connected layers transfer learning

Now let's adjust the strategy:

  • Although the VGG16 base model is also a convolutional network and was trained to convergence, its training set contains no handwritten digit images, so we can infer that VGG16 does not carry the "knowledge" needed for handwritten digit recognition
  • Fine-tuning VGG16 for handwritten digit recognition is essentially transfer learning, so the base model itself also needs some degree of re-learning and adjustment
from keras.models import Model, load_model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.backend import reshape, expand_dims, spatial_2d_padding, spatial_3d_padding
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Flatten, Dropout, Conv1D
from keras.optimizers import SGD
from keras.datasets import mnist
from keras.utils import to_categorical
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def vis_conv(images, n, name, t):
    """visualize conv output and conv filter.
    Args:
           images: stacked feature maps or filter images to tile.
           n: number of col and row.
           t: vis type.
           name: save name.
    """
    size = 64
    margin = 5

    if t == 'filter':
        # grid large enough for an n x n tiling with margins between cells
        results = np.zeros((n * size + (n - 1) * margin, n * size + (n - 1) * margin, 3))
    if t == 'conv':
        results = np.zeros((n * size + (n - 1) * margin, n * size + (n - 1) * margin))

    for i in range(n):
        for j in range(n):
            if t == 'filter':
                filter_img = images[i + (j * n)]
            if t == 'conv':
                filter_img = images[..., i + (j * n)]
            filter_img = cv2.resize(filter_img, (size, size))

            # Put the result in the square `(i, j)` of the results grid
            horizontal_start = i * size + i * margin
            horizontal_end = horizontal_start + size
            vertical_start = j * size + j * margin
            vertical_end = vertical_start + size
            if t == 'filter':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end, :] = filter_img
            if t == 'conv':
                results[horizontal_start: horizontal_end, vertical_start: vertical_end] = filter_img

    # Display the results grid
    plt.imshow(results)
    plt.savefig('images/{}_{}.jpg'.format(t, name), dpi=600)
    plt.show()


def conv_output(model, layer_name, img):
    """Get the output of conv layer.

    Args:
           model: keras model.
           layer_name: name of layer in the model.
           img: processed input image.

    Returns:
           intermediate_output: feature map.
    """
    # this is the placeholder for the input images
    input_img = model.input

    try:
        # this is the placeholder for the conv output
        out_conv = model.get_layer(layer_name).output
    except Exception:
        raise Exception('No layer named {}!'.format(layer_name))

    # get the intermediate layer model
    intermediate_layer_model = Model(inputs=input_img, outputs=out_conv)

    # get the output of intermediate layer model
    intermediate_output = intermediate_layer_model.predict(img)

    return intermediate_output[0]


def get_mnist_data():
    (X_train_data, Y_train_data), (X_test_data, Y_test_data) = mnist.load_data()
    # convert Y label into one-hot
    Y_train_data = to_categorical(Y_train_data)
    Y_test_data = to_categorical(Y_test_data)
    X_train_data = X_train_data.astype('float32') / 255.0
    X_test_data = X_test_data.astype('float32') / 255.0

    # pad the mnist data from 28x28x1 up to 48x48x3 to match the VGG16 input
    X_train_data = expand_dims(X_train_data, axis=-1)
    X_test_data = expand_dims(X_test_data, axis=-1)
    X_train_data = tf.pad(X_train_data, [[0, 0], [2, 18], [2, 18], [1, 1]])
    X_test_data = tf.pad(X_test_data, [[0, 0], [2, 18], [2, 18], [1, 1]])

    # split off a validation set from the training data
    X_train_val = X_train_data[-2000:, ...]
    X_train_data = X_train_data[:-2000, ...]
    Y_train_val = Y_train_data[-2000:]
    Y_train_data = Y_train_data[:-2000]

    print("np.shape(X_train_data): ", np.shape(X_train_data))
    print("np.shape(X_test_data): ", np.shape(X_test_data))
    print("np.shape(X_train_val): ", np.shape(X_train_val))
    print("np.shape(Y_train_data): ", np.shape(Y_train_data))
    print("np.shape(Y_train_val): ", np.shape(Y_train_val))
    print("Y_train_data[0]: ", Y_train_data[0])

    return (X_train_data, Y_train_data), (X_test_data, Y_test_data), (X_train_val, Y_train_val)


def train_fine_tune(base_model_freeze=True):
    # create the base pre-trained model
    base_model = VGG16(weights='imagenet', include_top=False, input_shape=(48, 48, 3))

    # create a fine-tune model
    x = base_model.output
    print("base_model.input_shape: ", base_model.input_shape)
    print("base_model.input_shape[1:]: ", base_model.input_shape[1:])
    print("base_model.output_shape: ", base_model.output_shape)
    print("base_model.output_shape[1:]: ", base_model.output_shape[1:])
    # let's add a fully-connected layer
    x = Flatten()(x)
    x = Dense(4096, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(4096, activation='relu')(x)
    x = Dropout(0.5)(x)
    # and a logistic layer -- let's say we have 10 classes
    predictions = Dense(10, activation='softmax')(x)

    # this is the new model(vgg16+fine-tune model) we will train
    model = Model(inputs=base_model.input, outputs=predictions)
    print("model.input_shape: ", model.input_shape)
    print("model.input_shape[1:]: ", model.input_shape[1:])
    print("model.output_shape: ", model.output_shape)
    print("model.output_shape[1:]: ", model.output_shape[1:])
    model.summary()

    if base_model_freeze:
        # i.e. freeze all convolutional VGG16 layers
        for layer in base_model.layers:
            layer.trainable = False
    else:
        for layer in base_model.layers:
            layer.trainable = True

    # compile the model (should be done *after* setting layers to non-trainable)
    sgd = SGD(lr=1e-5, decay=1e-6, momentum=0.5, nesterov=True)  # optimizer; note that the fine-tune learning rate is typically ~10x smaller than the rate used to train the base model from scratch
    # model.compile(loss='categorical_crossentropy', optimizer="rmsprop")
    model.compile(loss='categorical_crossentropy', optimizer=sgd)

    # load the mnist data, and fine-tune the new model
    (X_train_data, Y_train_data), (X_test_data, Y_test_data), (X_train_val, Y_train_val) = get_mnist_data()

    # train the model on the new data for a few epochs
    history = model.fit(
        X_train_data, Y_train_data,
        batch_size=32,
        epochs=20,
        validation_data=(X_train_val, Y_train_val)
    )

    if base_model_freeze:
        model.save('vgg16_plus_dnn_for_mnist_base_model_freeze.h5')
    else:
        model.save('vgg16_plus_dnn_for_mnist_base_model_train.h5')


def predict():
    model = load_model("./vgg16_plus_dnn_for_mnist_base_model_train.h5")

    (X_train_data, Y_train_data), (X_test_data, Y_test_data), (X_train_val, Y_train_val) = get_mnist_data()

    # view the previously misclassified test images
    plt.imshow(X_test_data[0])
    plt.show()
    plt.imshow(X_test_data[8])
    plt.show()
    plt.imshow(X_test_data[9])
    plt.show()

    print("labels of the first 10 test images: \n", np.argmax(Y_test_data[:10], axis=1))
    print("predictions for the first 10 test images:\n", np.argmax(model.predict(X_test_data[:10]), axis=1))


def visual_cnnkernel_with_number_6():
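    # NOTE: this helper only loads and preprocesses a sample digit image at
    # VGG16's ImageNet input size (224, 224); the feature-map visualization
    # itself relies on the conv_output/vis_conv helpers defined above.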
    img_path = '6.webp'
    img = image.load_img(img_path, target_size=(224, 224))
    plt.imshow(img)
    plt.show()

    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    print("np.shape(x): ", np.shape(x))


if __name__ == '__main__':
    train_fine_tune(base_model_freeze=False)
    # predict()
    # visual_cnnkernel_with_number_6()

As we can see, compared with freezing the pre-trained base model, once the base model's parameters are allowed to be fine-tuned, both val_loss and loss start to decline steadily during training. This shows that the base model plus the fine-tune layers are jointly being optimized toward fitting the handwritten digit recognition task, and that the base model itself also needs to re-learn and adapt to this downstream scenario.

Conversely, note that in the earlier section where the base model's parameters were frozen, val_loss and loss never came down, and the model stayed in an underfitted state.
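
To see the difference directly, the History objects returned by model.fit in the two runs can be plotted together. A minimal sketch, assuming train_fine_tune above is modified to return its history, and history_frozen / history_unfrozen are hypothetical names for the two returned objects:

import matplotlib.pyplot as plt

def plot_finetune_histories(history_frozen, history_unfrozen):
    # compare training/validation loss of the frozen vs. trainable base-model runs
    plt.plot(history_frozen.history['loss'], label='frozen base: loss')
    plt.plot(history_frozen.history['val_loss'], label='frozen base: val_loss')
    plt.plot(history_unfrozen.history['loss'], label='trainable base: loss')
    plt.plot(history_unfrozen.history['val_loss'], label='trainable base: val_loss')
    plt.xlabel('epoch')
    plt.ylabel('categorical cross-entropy loss')
    plt.legend()
    plt.show()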

Load the fine-tuned model and run it on the test set:

labels of the first 10 test images: 
 [7 2 1 0 4 1 4 9 5 9]

predictions for the first 10 test images:
 [7 2 1 0 4 1 4 9 5 9]

As we can see, once the pre-trained base model's parameters are allowed to update, the model obtained by fine-tuning on the MNIST dataset adapts very well to the new recognition task.
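
Spot-checking ten images is only indicative; a fuller check is to compute accuracy over the entire MNIST test set. A minimal sketch, assuming the get_mnist_data helper above is importable and the model was saved as vgg16_plus_dnn_for_mnist_base_model_train.h5:

import numpy as np
from keras.models import load_model

model = load_model('./vgg16_plus_dnn_for_mnist_base_model_train.h5')
_, (X_test_data, Y_test_data), _ = get_mnist_data()

# predicted vs. true class indices over the whole test set
pred_labels = np.argmax(model.predict(X_test_data, batch_size=32), axis=1)
true_labels = np.argmax(Y_test_data, axis=1)
print('test accuracy: {:.4f}'.format(np.mean(pred_labels == true_labels)))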

Next, visualize the base model's convolution kernels via their feature maps, and observe how they change before and after fine-tuning.
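
A minimal sketch of how the conv_output and vis_conv helpers above could be driven to produce these images for a single digit; block1_conv1 etc. are Keras's VGG16 layer names, the 8x8 grid assumes the layer has at least 64 channels, and vis_conv saves into an images/ directory:

model = load_model('./vgg16_plus_dnn_for_mnist_base_model_train.h5')
_, (X_test_data, _), _ = get_mnist_data()
img = np.expand_dims(np.array(X_test_data[0]), axis=0)    # shape (1, 48, 48, 3)

for layer_name in ['block1_conv1', 'block1_conv2', 'block2_conv1',
                   'block2_conv2', 'block3_conv1', 'block3_conv2', 'block5_conv3']:
    feature_maps = conv_output(model, layer_name, img)    # (H, W, channels)
    vis_conv(feature_maps, 8, layer_name, 'conv')         # tile 64 channels as 8x8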

(Figures omitted) Feature-map visualizations labeled conv_block1_conv1, conv_block1_conv2, conv_block2_conv1, conv_block2_conv2, conv_block3_conv1, conv_block3_conv2, and conv_block5_conv3.

As we can see, after fine-tuning, the regions the base model's convolution kernels respond to, and the strength of those responses, both increase, which is why the fine-tuned model becomes better at recognizing MNIST handwritten digits.
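
One way to make this observation more quantitative is to compare each convolution layer's kernels before and after fine-tuning. A minimal sketch, assuming the fully trainable run's model file from above is available:

import numpy as np
from keras.models import load_model
from tensorflow.keras.applications.vgg16 import VGG16

pretrained = VGG16(weights='imagenet', include_top=False, input_shape=(48, 48, 3))
fine_tuned = load_model('./vgg16_plus_dnn_for_mnist_base_model_train.h5')

for layer in pretrained.layers:
    if 'conv' not in layer.name:
        continue
    w_before = layer.get_weights()[0]                            # ImageNet kernels
    w_after = fine_tuned.get_layer(layer.name).get_weights()[0]  # fine-tuned kernels
    drift = np.linalg.norm(w_after - w_before) / np.linalg.norm(w_before)
    print('{:<14s} relative kernel change: {:.4f}'.format(layer.name, drift))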

0x4: The effect of the downstream task's sample size on fine-tune performance
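
A minimal sketch of how this could be studied with the pieces above: fine-tune on progressively larger subsets of the MNIST training set and compare the resulting validation loss. train_fine_tune_on is a hypothetical variant of train_fine_tune that accepts the training arrays directly and returns the fit History:

(X_train_data, Y_train_data), _, (X_train_val, Y_train_val) = get_mnist_data()

for n_samples in [500, 2000, 10000, len(Y_train_data)]:
    X_sub = X_train_data[:n_samples]
    Y_sub = Y_train_data[:n_samples]
    # train_fine_tune_on(...) is hypothetical: same model-building code as
    # train_fine_tune above, but fed with the given subset and returning history
    history = train_fine_tune_on(X_sub, Y_sub, X_train_val, Y_train_val,
                                 base_model_freeze=False)
    print('n_samples={}: final val_loss={:.4f}'.format(
        n_samples, history.history['val_loss'][-1]))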

Reference links:

https://arxiv.org/pdf/1409.1556.pdf 
https://www.cnblogs.com/LittleHann/p/6792511.html 
https://github.com/xiaochus/VisualizationCNN/blob/master/vis.py 
https://www.quora.com/What-is-the-VGG-neural-network 
https://blog.csdn.net/wphkadn/article/details/86772708
https://zhuanlan.zhihu.com/p/483559442 
https://keras.io/api/applications/
https://blog.csdn.net/newlw/article/details/126127251 
