在Tensorflow中使用带有队列的While循环

时间:2017-07-18 01:58:02

标签: while-loop tensorflow queue

我有一组图像,我将在张量流中输入图形。通过FIFOQueue获取数据。问题是在某些图像中,未检测到脸部,即图像不包含脸部。因此,在将它们输入图形之前,我将忽略这些图像。我的代码如下:

import tensorflow as tf
import numpy as np

num_epoch = 100

tfrecords_filename_seq = ["C:/Users/user/PycharmProjects/AffectiveComputing/P16_db.tfrecords"]
filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue')
reader = tf.TFRecordReader()

current_image_confidence = tf.Variable(tf.constant(0.0, dtype=tf.float32))
image = tf.Variable(tf.ones([112, 112, 3]), dtype=tf.float32)
annotation = tf.Variable('', dtype=tf.string)

def body():
    key, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'annotation_raw': tf.FixedLenFeature([], tf.string)
        })

    # This is how we create one example, that is, extract one example from the database.
    image_ = tf.decode_raw(features['image_raw'], tf.uint8)
    # The height and the weights are used to
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)

    # The image is reshaped since when stored as a binary format, it is flattened. Therefore, we need the
    # height and the weight to restore the original image back.
    image.assign(tf.reshape(image_, [height, width, 3]))

    annotation.assign(tf.cast(features['annotation_raw'], tf.string))
    current_image_confidence.assign(tf.slice(tf.string_to_number(tf.string_split(annotation, delimiter=','),
                                                            out_type=tf.float32),
                                        begin=[0, 3],
                                        size=[1, 1]))

def cond():
    tf.equal(current_image_confidence, tf.constant(0.0, dtype=tf.float32))

loop = tf.while_loop(cond, body, [current_image_confidence, reader, image, annotation])

因此,我需要一个while循环,直到我得到一张带脸的图像。那时我需要终止循环并将图像发送到图形。

请注意,我的数据存储在tfrecord文件中。因此每条记录包含一个图像和一组称为注释的特征,保存为tf.string。因此, current_image_confidence 变量用于根据面是否存在来保持值1或0。

如何修复代码??

非常感谢任何帮助!!

0 个答案:

没有答案