码迷,mamicode.com
首页 > 其他好文 > 详细

Fine-tune with Pretrained Models

时间:2020-06-07 13:03:18      阅读:67      评论:0      收藏:0      [点我收藏+]

标签:mod   html   代码   ram   val   分数   tca   ams   read   

Gluon版本微调见这里。基于NDarray,类似于Pytorch动态图。而module版本类似于TF,基于Symbol,用的是静态graph。一般静态图用于快速调试见效果,而静态图效率高,速度快,实际中应更多使用。

本文基于module和symbol。利用imagenet训好的模型来微调caltech-256数据集。首先是制作数据:

训练集每类随机采样60张图,其余作为验证集。将图像resize成256,并打包成rec文件:

wget http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar   # 下载解压
tar -xf 256_ObjectCategories.tar

mkdir -p caltech_256_train_60              # 划分数据集,训练集每类60张图
for i in 256_ObjectCategories/*; do
    c=`basename $i`
    mkdir -p caltech_256_train_60/$c
    for j in `ls $i/*.jpg | shuf | head -n 60`; do
        mv $j caltech_256_train_60/$c/
    done
done

python3 im2rec.py --list --recursive caltech-256-60-train caltech_256_train_60/
python3 im2rec.py --list --recursive caltech-256-60-val 256_ObjectCategories/
python3 im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-val 256_ObjectCategories/
python3 im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-train caltech_256_train_60/

这段代码实际上是新建了个文件夹,然后把所有数据剪切出去一个训练集,然后针对两份数据生成lst和rec文件。

当然也可以解压后直接这样生成:

tar -xf 256_ObjectCategories.tar
python3 im2rec.py --list --recursive --train-ratio 0.6 caltech-256-60 256_ObjectCategories

python3 im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60 256_ObjectCategories

注意在生成的时候可能会发生段错误:

技术图片

 

这是因为在我的电脑上执行resize 256的时候报错,当我把线程数--num-thread改为1的时候就可以了。或者可以resize 更小例如128的时候4线程就ok。生成的文件:

技术图片 

不想自己生成,官方也给出了这些文件的下载:

import os, sys

if sys.version_info[0] >= 3:
    from urllib.request import urlretrieve
else:
    from urllib import urlretrieve

def download(url):
    filename = url.split("/")[-1]
    if not os.path.exists(filename):
        urlretrieve(url, filename)
download(http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec)
download(http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec)

然后可以定义data iter:

import mxnet as mx

def get_iterators(batch_size, data_shape=(3, 224, 224)):
    train = mx.io.ImageRecordIter(
        path_imgrec         = ./caltech-256-60-train.rec,
        data_name           = data,
        label_name          = softmax_label,
        batch_size          = batch_size,
        data_shape          = data_shape,
        shuffle             = True,
        rand_crop           = True,
        rand_mirror         = True)
    val = mx.io.ImageRecordIter(
        path_imgrec         = ./caltech-256-60-val.rec,
        data_name           = data,
        label_name          = softmax_label,
        batch_size          = batch_size,
        data_shape          = data_shape,
        rand_crop           = False,
        rand_mirror         = False)
    return (train, val)

下载预训练的resnet18权重并载入。

def get_model(prefix, epoch):
    download(prefix+-symbol.json)
    download(prefix+-%04d.params % (epoch,))

get_model(http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-18, 0)
sym, arg_params, aux_params = mx.model.load_checkpoint(resnet-18, 0)

技术图片

然后可以开始训练:

首先定义一个函数替代最后的一层全连接:

def get_fine_tune_model(symbol, arg_params, num_classes, layer_name=flatten0):
    """
    symbol: the pretrained network symbol
    arg_params: the argument parameters of the pretrained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer
    """
    all_layers = symbol.get_internals()     # 得到所有层
    net = all_layers[layer_name+_output]     # 注意这里的操作很反直觉,这句话意思是一直取到名字为layer_name的层
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name=fc1)      # 新建一个分类层
    net = mx.symbol.SoftmaxOutput(data=net, name=softmax)              # 输出softmax概率
    new_args = dict({k:arg_params[k] for k in arg_params if fc1 not in k})     # 除了新的全连接层,载入已有的权重
    return (net, new_args)     # 返回新的网络symbol结构和参数

symbol是和module搭档的,有了symbol,就可以新建module来喂入数据:

import logging
head = %(asctime)-15s %(message)s
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus):
    devs = [mx.gpu(i) for i in range(num_gpus)]
    mod = mx.mod.Module(symbol=symbol, context=devs)          # 新建一个module
    mod.fit(train, val,     # train和val的 data iter
        num_epoch=8,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback = mx.callback.Speedometer(batch_size, 10),        # 每10个批量后打印一次训练速度和评价指标metric的值
        kvstore=device, 
        optimizer=sgd,
        optimizer_params={learning_rate:0.01},
        initializer=mx.init.Xavier(rnd_type=gaussian, factor_type="in", magnitude=2),
        eval_metric=acc)
    metric = mx.metric.Accuracy()
    return mod.score(val, metric)

跑起来:

num_classes = 256
batch_per_gpu = 16
num_gpus = 8

(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)     # 得到新的symbol和参数

batch_size = batch_per_gpu * num_gpus       # 计算总的批量
(train, val) = get_iterators(batch_size)       # 根据批量得到data iter
mod_score = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)    # 训练
assert mod_score > 0.77, "Low training accuracy." 

 

Fine-tune with Pretrained Models

标签:mod   html   代码   ram   val   分数   tca   ams   read   

原文地址:https://www.cnblogs.com/king-lps/p/13060039.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!