在Tensor Flow中访问文件队列中的文件名

时间:2015-12-02 19:21:42

标签: python tensorflow

我有一个图像目录,以及一个将图像文件名与标签匹配的单独文件。因此,图像目录中包含' train / 001.jpg'等文件。标签文件如下:

train/001.jpg 1
train/002.jpg 2
...

我可以通过从文件名创建文件队列,轻松地从Tensor Flow中的图像目录加载图像:

filequeue = tf.train.string_input_producer(filenames)
reader = tf.WholeFileReader()
img = reader.read(filequeue)

但我对如何将这些文件与标签文件中的标签结合起来感到茫然。看来我需要在每一步都访问队列中的文件名。有办法获得它们吗?此外,一旦我有文件名,我需要能够查找由文件名键入的标签。似乎标准的Python字典不起作用,因为这些计算需要在图中的每一步发生。

5 个答案:

答案 0 :(得分:12)

鉴于您的数据不是太大而无法提供文件名列表作为python数组,我建议您只使用Python进行预处理。创建文件名和标签的两个列表(相同顺序),并将它们插入randomshufflequeue或队列中,然后从中出列。如果你想要string_input_producer的“循环无限”行为,你可以在每个纪元的开头重新运行'enqueue'。

一个很好的玩具示例:

import tensorflow as tf

f = ["f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8"]
l = ["l1", "l2", "l3", "l4", "l5", "l6", "l7", "l8"]

fv = tf.constant(f)
lv = tf.constant(l)

rsq = tf.RandomShuffleQueue(10, 0, [tf.string, tf.string], shapes=[[],[]])
do_enqueues = rsq.enqueue_many([fv, lv])

gotf, gotl = rsq.dequeue()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    tf.train.start_queue_runners(sess=sess)
    sess.run(do_enqueues)
    for i in xrange(2):
        one_f, one_l = sess.run([gotf, gotl])
        print "F: ", one_f, "L: ", one_l

关键是,当您执行enqueue时,您有效地将文件名/标签对排入队列,并且dequeue会返回这些对。

答案 1 :(得分:4)

这是我能够做到的。

我首先将文件名洗牌并在Python中将标签与它们匹配:

np.random.shuffle(filenames)
labels = [label_dict[f] for f in filenames]

然后为shuffle off创建了一个string_input_producer文件名,并为标签创建了一个FIFO:

lv = tf.constant(labels)
label_fifo = tf.FIFOQueue(len(filenames),tf.int32, shapes=[[]])
file_fifo = tf.train.string_input_producer(filenames, shuffle=False, capacity=len(filenames))
label_enqueue = label_fifo.enqueue_many([lv])

然后,为了阅读图像,我可以使用WholeFileReader并获得标签,我可以将fifo出列:

reader = tf.WholeFileReader()
image = tf.image.decode_jpeg(value, channels=3)
image.set_shape([128,128,3])
result.uint8image = image
result.label = label_fifo.dequeue()

按如下方式生成批次:

min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(num_examples_per_epoch *
                         min_fraction_of_examples_in_queue)
num_preprocess_threads = 16
images, label_batch = tf.train.shuffle_batch(
  [result.uint8image, result.label],
  batch_size=FLAGS.batch_size,
  num_threads=num_preprocess_threads,
  capacity=min_queue_examples + 3 * FLAGS.batch_size,
  min_after_dequeue=min_queue_examples)

答案 2 :(得分:1)

您可以使用tf.py_func()来实现从文件路径到标签的映射。

files = gfile.Glob(data_pattern)
filename_queue = tf.train.string_input_producer(
files, num_epochs=num_epochs, shuffle=True) #  list of files to read

def extract_label(s):
    # path to label logic for cat&dog dataset
    return 0 if os.path.basename(str(s)).startswith('cat') else 1

def read(filename_queue):
  key, value = reader.read(filename_queue)
  image = tf.image.decode_jpeg(value, channels=3)
  image = tf.cast(image, tf.float32)
  image = tf.image.resize_image_with_crop_or_pad(image, width, height)
  label = tf.cast(tf.py_func(extract_label, [key], tf.int64), tf.int32)
  label = tf.reshape(label, [])

training_data = [read(filename_queue) for _ in range(num_readers)]

...

tf.train.shuffle_batch_join(training_data, ...)

答案 3 :(得分:0)

我用过这个:

 filename = filename.strip().decode('ascii')

答案 4 :(得分:0)

另一个建议是以TFRecord格式保存您的数据。在这种情况下,您可以将所有图像和所有标签保存在同一文件中。对于大量文件,它具有很多优点:

  • 可以在同一个地方存储数据和标签
  • 数据在一个地方分配(无需记住各种目录)
  • 如果文件(图像)很多,打开/关闭文件非常耗时。从ssd / hdd寻找文件的位置也需要时间
相关问题