码迷,mamicode.com
首页 > 其他好文 > 详细

读取自己的数据集

时间:2020-06-03 12:13:59      阅读:74      评论:0      收藏:0      [点我收藏+]

标签:not   tor   code   训练   归一化   处理   color   ati   16px   

  图像分类任务中,大多数教程是直接导入深度学习库中的数据集直接用于模型训练,如果采用自己的数据集,会难以下手,这篇博客主要介绍使用Tensorflow2.1或Keras来读取自己的数据集。

1、Tensorflow方法制作数据集

   Tensorflow制作数据集,主要用到tf.data进行操作。步骤为制作csv文件、读取csv、读取数据、数据处理。

需要用到的库

import os
import random
import glob
import csv
import tensorflow as tf

1.1 制作csv文件

# 创建csv文件,输入分别为路径和要创建的csv文件名
def build_csv(root, filename):
    # 对种类进行编号,相当于用0,1,2分别表示这三个水果类别
    name2label = {}
    for name in sorted(os.listdir(os.path.join(root))):
        # 判断文件夹下的对象是否是一个文件夹
        # 不是文件夹,直接进行下一次判断
        # 是文件夹,对该目录进行编号
        if not os.path.isdir(os.path.join(root, name)):
            continue
        name2label[name] = len(name2label.keys())
    # 准备从每个文件夹中读取图片路径与编号
    images = []
    # 遍历数据集中的每个文件夹
    for name in name2label.keys():
        # 读取所有的png,jpg,jpeg格式的文件
        images += glob.glob(os.path.join(root, name, *.png))
        images += glob.glob(os.path.join(root, name, *.jpg))
        images += glob.glob(os.path.join(root, name, *.jpeg))
    print(len(images), images)
    random.shuffle(images)
    # 创建并写csv文件
    with open(os.path.join(root, filename), mode=w, newline=‘‘) as f:
        writer = csv.writer(f)
        for img in images:
            # 更改路径的分隔符
            name = img.split(os.sep)[-2]
            label = name2label[name]
            writer.writerow([img, label])
        print(written into csv file:, filename)

1.2读取csv文件

# 输入分别为路径和刚刚创建的csv文件名
def load_csv(root, filename):
    images, labels = [], []
    with open(os.path.join(root, filename)) as f:
        reader = csv.reader(f)
        for row in reader:
            img, label = row
            label = int(label)
            images.append(img)
            labels.append(label)
    return images, labels

1.3将数据集转换为tf.data格式

# 读取csv文件
images, labels = load_csv(root, filename)
# 转换为tf.data格式
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
# 数据处理操作,其中preprocessing是需要自己编写的一个实现数据处理功能的函数
dataset = dataset.shuffle(1000).map(preprocess).batch(32)

1.4数据处理操作

# 输入为路径和标签
def preprocess(x, y):
    # 根据路径读取图片
    x = tf.io.read_file(x)
    # 将图片数值转换为张量
    x = tf.image.decode_jpeg(x, channels=3)
    # 更改尺寸
    x = tf.image.resize(x, [244, 244])
    # 归一化
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.convert_to_tensor(y)

    return x, y

2、Keras方法制作数据集

   Keras制作数据集,使用Keras进行导入数据集。使用keras导入数据集,过程简单方便。

需要用到的库

from keras.preprocessing.image import ImageDataGenerator

2.1读取数据

# 将照片[0-255]数据缩放为[0-1]
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# 训练集与验证集路径
train_dir = "train/"
validation_dir = "validation/"

# 生成了224x224的RGB图像,形状为[20,224,224,3]与二进制标签[20,]的批量,每个批量包含20个样本
train_generator = train_datagen.flow_from_directory(
    train_dir,                  # 训练集路径
    target_size=(224, 224),     # 训练集样本尺寸大小为(224, 224)
    batch_size=32,              # 训练集每批包含20个样本
    class_mode=‘categorical)    
validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=16,
    class_mode=‘categorical)

2.2 输入数据到模型

history = model.fit_generator(
    train_generator,           
    validation_data=validation_generator,
  ......
)

         

读取自己的数据集

标签:not   tor   code   训练   归一化   处理   color   ati   16px   

原文地址:https://www.cnblogs.com/zonghui/p/12561274.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!