
Keras Cats vs. Dogs, Part 8: Transfer Learning with a Pretrained ResNet50, Normalizing Images with preprocess_input First Raises Accuracy to 97.5%

Posted: 2019-12-07 22:52:22


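Building on the previous post, this time the images are preprocessed with the Keras function preprocess_input before training. For ResNet50 this is Caffe-style preprocessing (channel reordering plus ImageNet mean subtraction, not scaling to [0, 1]), which is the input format the pretrained ImageNet weights expect.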
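A minimal NumPy sketch of the Caffe-style preprocessing that keras.applications.resnet50.preprocess_input applies to channels-last images: reverse the channel order (RGB to BGR), then subtract the ImageNet per-channel means (103.939, 116.779, 123.68 in BGR order). The name caffe_style_preprocess is hypothetical, and this is an approximation for illustration only; the Keras function itself also handles channels-first data.

```python
import numpy as np

# ImageNet per-channel means in BGR order, as used by Caffe-style preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_style_preprocess(rgb_images):
    """Hypothetical helper approximating resnet50.preprocess_input (channels-last).

    Converts RGB -> BGR and subtracts the ImageNet mean of each channel.
    No division by 255 or by a standard deviation is applied.
    """
    x = np.asarray(rgb_images, dtype=np.float32)
    x = x[..., ::-1]               # reverse the channel axis: RGB -> BGR
    return x - IMAGENET_MEAN_BGR   # zero-center each channel


# One pure-red 1x1 RGB image
red = np.array([[[255, 0, 0]]], dtype=np.uint8)
print(caffe_style_preprocess(red))
```

Pixel values therefore end up roughly in [-124, 152] instead of [0, 255] or [0, 1].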

Import preprocess_input:

import os

from keras import layers, optimizers, models
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.layers import *    
from keras.models import Model

Pass preprocessing_function=preprocess_input to the data generators:

from keras.preprocessing.image import ImageDataGenerator

batch_size = 64

# train_dir and validation_dir are the dataset directories defined earlier in this series
train_datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    # Runs on every image after it has been resized and augmented
    preprocessing_function=preprocess_input)

# Validation images get the same preprocessing, but no augmentation
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)


train_generator = train_datagen.flow_from_directory(
        # This is the target directory
        train_dir,
        # All images will be resized to 150x150
        target_size=(150, 150),
        batch_size=batch_size,
        # Since we use binary_crossentropy loss, we need binary labels
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary')

Training ran for 25 epochs, with the learning rate lowered from 1e-3 to 4e-5:

Epoch 1/100
281/281 [==============================] - 152s 540ms/step - loss: 0.2849 - acc: 0.8846 - lr: 0.0010 - val_loss: 0.1195 - val_acc: 0.9694 - val_lr: 0.0010
Epoch 2/100
281/281 [==============================] - 79s 282ms/step - loss: 0.2234 - acc: 0.9079 - lr: 0.0010 - val_loss: 0.1105 - val_acc: 0.9673 - val_lr: 0.0010
Epoch 3/100
281/281 [==============================] - 80s 285ms/step - loss: 0.2070 - acc: 0.9135 - lr: 0.0010 - val_loss: 0.1061 - val_acc: 0.9716 - val_lr: 0.0010
Epoch 4/100
281/281 [==============================] - 80s 283ms/step - loss: 0.1939 - acc: 0.9203 - lr: 0.0010 - val_loss: 0.0998 - val_acc: 0.9748 - val_lr: 0.0010
Epoch 5/100
......
Epoch 22/100
281/281 [==============================] - 80s 284ms/step - loss: 0.1368 - acc: 0.9470 - lr: 4.0000e-05 - val_loss: 0.0943 - val_acc: 0.9777 - val_lr: 4.0000e-05
Epoch 23/100
281/281 [==============================] - 80s 283ms/step - loss: 0.1346 - acc: 0.9479 - lr: 4.0000e-05 - val_loss: 0.1046 - val_acc: 0.9720 - val_lr: 4.0000e-05
Epoch 24/100
281/281 [==============================] - 79s 283ms/step - loss: 0.1320 - acc: 0.9476 - lr: 4.0000e-05 - val_loss: 0.0938 - val_acc: 0.9759 - val_lr: 4.0000e-05
Epoch 25/100
281/281 [==============================] - 79s 282ms/step - loss: 0.1356 - acc: 0.9476 - lr: 4.0000e-05 - val_loss: 0.1063 - val_acc: 0.9745 - val_lr: 4.0000e-05
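The post does not show how the learning rate was lowered. The endpoints are consistent with two reductions by a factor of 0.2 (1e-3 × 0.2 = 2e-4, then 2e-4 × 0.2 = 4e-5), which is what Keras's ReduceLROnPlateau callback would produce with factor=0.2, but that is an assumption. The pure-Python sketch below imitates the callback's basic rule without cooldown: if the validation loss has not improved by at least min_delta for patience consecutive epochs, multiply the learning rate by factor. The function name and the loss values are made up for illustration.

```python
def reduce_on_plateau_schedule(val_losses, initial_lr=1e-3, factor=0.2,
                               patience=2, min_lr=0.0, min_delta=1e-4):
    """Return the learning rate in effect during each epoch.

    Hypothetical, simplified imitation of a reduce-on-plateau rule (no cooldown):
    the rate is multiplied by `factor` after `patience` consecutive epochs
    without an improvement of at least `min_delta` in the validation loss.
    """
    lr = initial_lr
    best = float('inf')
    wait = 0
    schedule = []
    for loss in val_losses:
        schedule.append(lr)           # rate used while training this epoch
        if loss < best - min_delta:   # improvement: remember it, reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:      # plateau reached: shrink the learning rate
                lr = max(lr * factor, min_lr)
                wait = 0
    return schedule


# Made-up validation losses that plateau twice
losses = [1.0, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99]
for epoch, lr in enumerate(reduce_on_plateau_schedule(losses), start=1):
    print(f"epoch {epoch}: lr = {lr:.0e}")
```

In Keras itself this would correspond to passing something like ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=...) in callbacks=[...]; the patience the author used is not stated.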

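The test images must go through the same preprocess_input step. Because cv2.imread returns images in BGR order while preprocess_input expects RGB input, each image is converted to RGB first: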
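import os

import cv2
import numpy as np
from keras.applications.resnet50 import preprocess_input


def get_input_xy(src):
    """Load and preprocess test images; labels come from the file name prefix ('cat'/'dog')."""
    pre_x = []
    true_y = []

    class_indices = {'cat': 0, 'dog': 1}

    for s in src:
        img = cv2.imread(s)                         # OpenCV loads images as BGR uint8
        img = cv2.resize(img, (150, 150))           # same size as in training
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # preprocess_input expects RGB
        pre_x.append(preprocess_input(img.astype('float32')))

        _, fn = os.path.split(s)
        y = class_indices.get(fn[:3])               # e.g. 'cat.1.jpg' -> 0
        true_y.append(y)

    pre_x = np.array(pre_x)

    return pre_x, true_y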

    
import itertools

import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Print the raw matrix, the accuracy of each class, then the overall accuracy
    print(cm)
    ok_num = 0
    for k in range(cm.shape[0]):
        print(cm[k, k] / np.sum(cm[k, :]))
        ok_num += cm[k, k]

    print(ok_num / np.sum(cm))

    if normalize:
        cm = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]

    # Write each cell's value, using white text on dark cells
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
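The post does not show the code that turns model predictions into the matrix passed to the plotting function. A minimal NumPy sketch, assuming the model ends in a single sigmoid unit (consistent with class_mode='binary') so that prediction returns one probability of "dog" per image, with 0.5 as the decision threshold; binary_confusion_matrix is a hypothetical helper name. Rows are true labels and columns predicted labels, matching the plot's "True label" and "Predicted label" axes.

```python
import numpy as np


def binary_confusion_matrix(true_y, probs, threshold=0.5):
    """Hypothetical helper: 2x2 confusion matrix, rows = true class, columns = predicted.

    true_y: sequence of 0 (cat) / 1 (dog) labels.
    probs:  sigmoid outputs, shape (n,) or (n, 1), read as P(dog).
    """
    true_y = np.asarray(true_y, dtype=int)
    pred_y = (np.asarray(probs).ravel() > threshold).astype(int)
    if true_y.shape != pred_y.shape:
        raise ValueError("true_y and probs must describe the same number of images")
    cm = np.zeros((2, 2), dtype=int)
    np.add.at(cm, (true_y, pred_y), 1)   # count each (true, predicted) pair
    return cm


print(binary_confusion_matrix([0, 0, 1, 1], [0.1, 0.7, 0.8, 0.4]))
```

With the model from this series (variable names assumed), usage would look like probs = model.predict(pre_x) followed by cm = binary_confusion_matrix(true_y, probs). If scikit-learn is available, sklearn.metrics.confusion_matrix(true_y, pred_y) gives the same row/column layout.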

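The test accuracy is 97.5%, 1.3 percentage points higher than in the previous post. The output is the confusion matrix (rows are the true classes cat and dog, columns the predicted classes), followed by the accuracy on cats, the accuracy on dogs, and the overall accuracy: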

[[1225   25]
 [  38 1212]]
0.98
0.9696
0.9748
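The three numbers after the matrix can be recomputed from it directly: per-class accuracy (the recall of each class) is the diagonal divided by the row sums, and overall accuracy is the trace divided by the total count. The sketch below uses the matrix reported above and also computes per-class precision (diagonal divided by column sums), which the post does not report.

```python
import numpy as np

# Confusion matrix from the test run above (rows: true cat/dog, columns: predicted)
cm = np.array([[1225, 25],
               [38, 1212]])

recall = np.diag(cm) / cm.sum(axis=1)      # 1225/1250 for cats, 1212/1250 for dogs
precision = np.diag(cm) / cm.sum(axis=0)   # 1225/1263 for cats, 1212/1237 for dogs
overall_acc = np.trace(cm) / cm.sum()      # 2437/2500

print("per-class accuracy (recall):", recall)
print("per-class precision:", precision)
print("overall accuracy:", overall_acc)
```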


Original post: https://www.cnblogs.com/zhengbiqing/p/12003660.html
