码迷,mamicode.com
首页 > Web开发 > 详细

Keras猫狗大战七:resnet50预训练模型迁移学习优化,动态调整学习率,精度提高到96.2%

时间:2019-11-30 22:50:49      阅读:344      评论:0      收藏:0      [点我收藏+]

标签:tick   load   ros   iter   normal   迁移   卷积层   https   mod   

https://www.cnblogs.com/zhengbiqing/p/11780161.html中直接在resnet网络的卷积层后添加一层分类层,得到一个最简单的迁移学习模型,得到的结果为95.3%。

这里对最后的分类网络做些优化:用GlobalAveragePooling2D替换Flatten、增加一个密集连接层(同时添加BN、Activation、Dropout):

conv_base = ResNet50(weights=imagenet, include_top=False, input_shape=(150, 150, 3))
for layers in conv_base.layers[:]:
    layers.trainable = False
    
x = conv_base.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024)(x)
x = BatchNormalization()(x)
x = Activation(relu)(x)
x = Dropout(0.3)(x)
predictions = Dense(1, activation=sigmoid)(x)
model = Model(inputs=conv_base.input, outputs=predictions)

另外采用动态学习率,并且打印显示出学习率:

optimizer = optimizers.RMSprop(lr=1e-3)

def get_lr_metric(optimizer):
    def lr(y_true, y_pred):
        return optimizer.lr

    return lr

lr_metric = get_lr_metric(optimizer)

model.compile(loss=binary_crossentropy, optimizer=optimizer, metrics=[acc,lr_metric])

当模型的val_loss训练多轮不再下降时,提前结束训练:

from keras.callbacks import ReduceLROnPlateau,EarlyStopping

early_stop = EarlyStopping(monitor=val_loss, patience=13)
reduce_lr = ReduceLROnPlateau(monitor=val_loss, patience=7, mode=auto, factor=0.2)
callbacks = [early_stop,reduce_lr]

history = model.fit_generator(
      train_generator,
      steps_per_epoch=train_generator.samples//batch_size,
      epochs=100,
      validation_data=validation_generator,
      validation_steps=validation_generator.samples//batch_size,
      callbacks = callbacks)

共训练了61epochs,学习率从0.001下降到1.6e-6:

Epoch 1/100
281/281 [==============================] - 141s 503ms/step - loss: 0.3322 - acc: 0.8589 - lr: 0.0010 - val_loss: 0.2344 - val_acc: 0.9277 - val_lr: 0.0010
Epoch 2/100
281/281 [==============================] - 79s 279ms/step - loss: 0.2591 - acc: 0.8862 - lr: 0.0010 - val_loss: 0.2331 - val_acc: 0.9288 - val_lr: 0.0010
Epoch 3/100
281/281 [==============================] - 78s 279ms/step - loss: 0.2405 - acc: 0.8959 - lr: 0.0010 - val_loss: 0.2292 - val_acc: 0.9303 - val_lr: 0.0010
......
281/281 [==============================] - 77s 275ms/step - loss: 0.1532 - acc: 0.9407 - lr: 1.6000e-06 - val_loss: 0.1871 - val_acc: 0.9412 - val_lr: 1.6000e-06
Epoch 60/100
281/281 [==============================] - 78s 276ms/step - loss: 0.1492 - acc: 0.9396 - lr: 1.6000e-06 - val_loss: 0.1687 - val_acc: 0.9450 - val_lr: 1.6000e-06
Epoch 61/100
281/281 [==============================] - 77s 276ms/step - loss: 0.1468 - acc: 0.9414 - lr: 1.6000e-06 - val_loss: 0.1825 - val_acc: 0.9454 - val_lr: 1.6000e-06

加载模型:
optimizer = optimizers.RMSprop(lr=1e-3)

def get_lr_metric(optimizer):
    def lr(y_true, y_pred):
        return optimizer.lr

    return lr

lr_metric = get_lr_metric(optimizer)
model = load_model(model_file, custom_objects={lr:lr_metric})

修改混淆矩阵函数,以打印每个类别的精确度:

def plot_sonfusion_matrix(cm, classes, normalize=False, title=Confusion matrix, cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation=nearest, cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks([-0.5,1.5], classes)

    print(cm)
    ok_num = 0
    for k in range(cm.shape[0]):
        print(cm[k,k]/np.sum(cm[k,:]))
        ok_num += cm[k,k]
        
    print(ok_num/np.sum(cm))
        
    if normalize:
        cm = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j], horizontalalignment=center, color=white if cm[i, j] > thresh else black)

    plt.tight_layout()
    plt.ylabel(True label)
    plt.xlabel(Predict label)

测试结果为:

[[1200   50]
 [  45 1205]]
0.96
0.964
0.962
猫的准确度为96%,狗的为96.4%,总的准确度为96.2%。混淆矩阵图:

技术图片

Keras猫狗大战七:resnet50预训练模型迁移学习优化,动态调整学习率,精度提高到96.2%

标签:tick   load   ros   iter   normal   迁移   卷积层   https   mod   

原文地址:https://www.cnblogs.com/zhengbiqing/p/11964301.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!