标签:images 灰度 str div tensor 获取文件 eal amp put
下面的代码用于生成 tfrecord 文件(train.tfrecords 和 test.tfrecords),其中每张图片的大小是 227*227*1;各类别的文件夹以英文类别名命名,写入文件的 label 是该类别对应的整数编号。
原图片是 256*256*3 的 RGB .jpg 文件。由于制作数据集时对图片的颜色没有要求,为了节省空间,先把图片缩放到 227*227,再进行灰度化处理。
import tensorflow as tf
import os
import sys
from PIL import Image
import numpy as np
# Dataset directories (each holds one sub-folder per class name)
TRAIN_DATASET_DIR = "E:/python文件/tensorflow_learn/MyNet/images/train/"
TEST_DATASET_DIR = "E:/python文件/tensorflow_learn/MyNet/images/test/"
# Directory where the generated .tfrecords files are stored
TFRECORD_DIR = "E:/python文件/tensorflow_learn/MyNet/images/"
# Class names. This is an ordered tuple rather than a set on purpose: the
# integer label of a class is its position when the names are enumerated, and
# a set of strings can iterate in a different order on every interpreter run,
# which would silently change the class -> label mapping.
classes = ("apple_scab", "black_rot", "cedar_apple_rust", "healthy")
# Check whether the TFRecord files already exist
def _dataset_exists(tfrecord_dir):
    """Return True only if both split files are present in ``tfrecord_dir``.

    Args:
        tfrecord_dir: Directory expected to contain ``train.tfrecords`` and
            ``test.tfrecords``.

    Returns:
        False as soon as one of the two files is missing, True otherwise.
    """
    for split_name in ('train', 'test'):
        # Path of train.tfrecords / test.tfrecords inside tfrecord_dir
        output_filename = os.path.join(tfrecord_dir, split_name + '.tfrecords')
        if not tf.gfile.Exists(output_filename):
            return False
    return True
def int64_feature(values):
    """Build a ``tf.train.Feature`` holding ``values`` as an ``Int64List``.

    A tuple or list is passed through unchanged; anything else is treated as
    a single value and boxed into a one-element list first.
    """
    as_sequence = values if isinstance(values, (tuple, list)) else [values]
    int_list = tf.train.Int64List(value=as_sequence)
    return tf.train.Feature(int64_list=int_list)
def bytes_feature(values):
    """Build a ``tf.train.Feature`` whose ``BytesList`` holds ``values`` as its only item."""
    payload = tf.train.BytesList(value=[values])
    return tf.train.Feature(bytes_list=payload)
# Collect all files of one class
def _get_filenames_and_classes(dataset_dir):
    """Return the paths of the regular files directly inside ``dataset_dir``.

    Args:
        dataset_dir: Directory to list.

    Returns:
        A list of ``os.path.join(dataset_dir, name)`` paths, sorted by file
        name. Entries that are not regular files (e.g. sub-directories) are
        left out.
    """
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # record order reproducible between runs and machines.
    candidates = (
        os.path.join(dataset_dir, entry)
        for entry in sorted(os.listdir(dataset_dir))
    )
    # Only regular files can be opened as images by the caller.
    return [path for path in candidates if os.path.isfile(path)]
# Convert one split of the dataset to TFRecord format
def _convert_dataset(split_name, dataset_dir):
    """Write every image of one split to ``<TFRECORD_DIR>/<split_name>.tfrecords``.

    Each image is resized to 227x227, converted to mode ``'L'`` (grayscale)
    and stored as raw bytes under the ``img_raw`` key; the ``label`` key holds
    the integer index of the image's class.

    Args:
        split_name: ``'train'`` or ``'test'``; also the output file's stem.
        dataset_dir: Directory containing one sub-folder per class name.

    Raises:
        AssertionError: If ``split_name`` is neither ``'train'`` nor ``'test'``.
    """
    assert split_name in ('train', 'test')
    output_filename = os.path.join(TFRECORD_DIR, split_name + '.tfrecords')
    with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
        # sorted() pins the class -> label mapping: the position in this
        # enumeration is the stored label, so it must not depend on the
        # iteration order of whatever container `classes` is.
        for index, name in enumerate(sorted(classes)):
            class_path = os.path.join(dataset_dir, name)
            filenames = _get_filenames_and_classes(class_path)
            for i, img_name in enumerate(filenames):
                sys.stdout.write('\r>>%s %s Convering image: %d/%d' % (split_name, name, i+1, len(filenames)))
                print(str(img_name))
                sys.stdout.flush()
                # The context manager closes the image file once the pixel
                # data has been copied into the array.
                with Image.open(img_name) as image:
                    resized = image.resize((227, 227))
                    image_data = np.array(resized.convert('L'))  # convert to grayscale
                img_raw = image_data.tobytes()
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            'img_raw': bytes_feature(img_raw),
                            'label': int64_feature(index),
                        }
                    )
                )
                tfrecord_writer.write(example.SerializeToString())
# Skip the conversion when both TFRecord files already exist
if _dataset_exists(TFRECORD_DIR):
    print("文件已存在")
else:
    # Convert both splits, then report that the files were generated
    _convert_dataset('test', TEST_DATASET_DIR)
    _convert_dataset('train', TRAIN_DATASET_DIR)
    print('生成tfrecord文件!')
标签:images 灰度 str div tensor 获取文件 eal amp put
原文地址:https://www.cnblogs.com/lyf98/p/11965256.html