码迷,mamicode.com
首页 > 其他好文 > 详细

使用data_flow_ops构造batch数据集

时间:2019-08-20 23:46:30      阅读:98      评论:0      收藏:0      [点我收藏+]

标签:port   标签   表示   city   print   队列   aci   input   import   

1. tf.unstack(number, axis=0)  表示对数据进行拆分

import tensorflow as tf
import numpy as np


data = np.array([[1, 2, 3],
                 [2, 3, 4],
                 [4, 5, 6]])

filenames = tf.unstack(data)  #表示输入的数据
with tf.Session() as sess:
    for filename in filenames: 
        print(sess.run(filename))
# [1, 2, 3]
# [4, 5, 6]
# [7, 8, 9]

 

对数据进行合理的解读

import tensorflow as tf
from tensorflow.python.ops import data_flow_ops
import numpy as np

# 构造初始的数据
image_paths_placeholder = tf.placeholder(tf.string, shape=(None, 3), name=image_path)
label_paths_placeholder = tf.placeholder(tf.int32, shape=(None, 3), name=labels)
# 构造输入的队列
input_queue = data_flow_ops.FIFOQueue(capacity=3,
                                      dtypes=[tf.string, tf.int32],
                                      shapes=([3, ], [3, ]),
                                      shared_name=None, name=None)
# 将数据放入
enqueue_op = input_queue.enqueue_many([image_paths_placeholder, label_paths_placeholder])
# 进行变量初始化
init = tf.global_variables_initializer()

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
Y = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
filename_labels = []
with tf.Session() as sess:

    # 将数据进行打包输出
    filenames, labels = input_queue.dequeue()
    # print(sess.run(filenames))
    images = []
    for filename in tf.unstack(filenames): # 将数据集按照axis=0进行拆分
        images.append(filename) # 将数据进行拆分, 这里可以对图片进行处理
        # print(sess.run(filename))
    filename_labels.append([images, labels]) # 将图片和标签进行添加
    #
    # # 使用图片和标签构造batch_size数据集
    image_batch, label_batch = tf.train.batch_join(
        filename_labels, batch_size=1,
        shapes=[(), ()], enqueue_many=True,
        capacity= 4 * 10,
        allow_smaller_final_batch=True
    )
    image_batch = tf.identity(image_batch, image_batch)
    enqueue_op.run(feed_dict={image_paths_placeholder: X, label_paths_placeholder: Y})

    x = sess.run([image_batch])
    print(1)
    # print(sess.run(image_batch))
    # 将数据进行输入

 

使用data_flow_ops构造batch数据集

标签:port   标签   表示   city   print   队列   aci   input   import   

原文地址:https://www.cnblogs.com/my-love-is-python/p/11386184.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!