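What I am trying to do
I am trying to extract CNN features for my own images using a residual network based on https://github.com/ry/tensorflow-resnet. I plan to feed image data directly from JPG files before exploring how to convert the images into a single file.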
What I have done
I have read https://www.tensorflow.org/versions/r0.9/how_tos/reading_data/index.html and some related material on how to feed data, such as feeds and placeholders. Here is my code:
import tensorflow as tf
from convert import print_prob, checkpoint_fn, meta_fn
from image_processing import image_preprocessing
tf.app.flags.DEFINE_integer('batch_size', 1, "batch size")
tf.app.flags.DEFINE_integer('input_size', 224, "input image size")
tf.app.flags.DEFINE_integer('min_after_dequeue', 224, "min after dequeue")
tf.app.flags.DEFINE_integer('layers', 152, "The number of layers in the net")
tf.app.flags.DEFINE_integer('image_number', 6951, "number of images")
FLAGS = tf.app.flags.FLAGS
def placeholder_inputs():
    images_placeholder = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size, 3))
    label_placeholder = tf.placeholder(tf.int32, shape=FLAGS.batch_size)
    return images_placeholder, label_placeholder
def fill_feed_dict(image_ba, label_ba, images_pl, labels_pl):
    feed_dict = {
        images_pl: image_ba,
    }
    return feed_dict
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(FLAGS.image_number *
                         min_fraction_of_examples_in_queue)
dataset = tf.train.string_input_producer(["hollywood_test.txt"])
reader = tf.TextLineReader()
_, file_content = reader.read(dataset)
image_name, label, _ = tf.decode_csv(file_content, [[""], [""], [""]], " ")
label = tf.string_to_number(label)
num_preprocess_threads = 10
images_and_labels = []
with tf.Session() as sess:
    for thread_id in range(num_preprocess_threads):
        image_buffer = tf.read_file(image_name)
        bbox = []
        train = False
        image = image_preprocessing(image_buffer, bbox, train, thread_id)
        image = image_buffer
        images_and_labels.append([image, label])
    image_batch, label_batch = tf.train.batch_join(images_and_labels,
                                                   batch_size=FLAGS.batch_size,
                                                   capacity=min_queue_examples + 3 * FLAGS.batch_size)
    images_placeholder, labels_placeholder = placeholder_inputs()
    new_saver = tf.train.import_meta_graph(meta_fn(FLAGS.layers))
    new_saver.restore(sess, checkpoint_fn(FLAGS.layers))
    graph = tf.get_default_graph()
    prob_tensor = graph.get_tensor_by_name("prob:0")
    images = graph.get_tensor_by_name("images:0")
    feed_dict = fill_feed_dict(image_batch, label_batch, images, labels_placeholder)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    sess.run(tf.initialize_all_variables())
    prob = sess.run(prob_tensor, feed_dict=feed_dict)
    print_prob(prob[0])
    coord.request_stop()
    coord.join(threads)
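What my problem is
I am confused about why feed_dict does not accept a tensor as input. I noticed that the TensorFlow MNIST example even has its own separate function for generating batches, although TensorFlow already provides batching methods. The batch_join function, however, returns a tensor, so I do not know how to feed my result correctly. This code fails with the error TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.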
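The error message lists numpy ndarrays among the acceptable feed values, which suggests one possible direction: evaluate the batch tensor with sess.run first, then feed the resulting array. The following is only a minimal sketch of that idea, not a tested fix for the code above. It assumes a TensorFlow install that exposes the TF1-style API through tensorflow.compat.v1 (the question itself targets r0.9), and the function name evaluate_then_feed, the constant image_batch, and the reduce_mean "prob" op are hypothetical stand-ins for the batch_join output and the restored "images:0" and "prob:0" tensors.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF1-style API via the compat module


def evaluate_then_feed(batch_size=2, size=4):
    """Evaluate a batch tensor to a numpy array, then feed that array."""
    graph = tf.Graph()
    with graph.as_default():
        # Hypothetical stand-in for the tensor returned by tf.train.batch_join.
        image_batch = tf.ones([batch_size, size, size, 3], dtype=tf.float32)
        # Hypothetical stand-ins for the restored "images:0" and "prob:0" tensors.
        images = tf.placeholder(tf.float32, shape=(batch_size, size, size, 3))
        prob_tensor = tf.reduce_mean(images, axis=[1, 2, 3])

        with tf.Session(graph=graph) as sess:
            # Step 1: run the batch tensor so its value comes back as an ndarray.
            image_values = sess.run(image_batch)
            # Step 2: feed the ndarray, which is an accepted feed value.
            prob = sess.run(prob_tensor, feed_dict={images: image_values})
    return image_values, prob


if __name__ == "__main__":
    image_values, prob = evaluate_then_feed()
    print(type(image_values), prob.shape)
```

With a queue-backed batch such as the one produced by batch_join, the queue runners would have to be started before step 1, because sess.run on the batch tensor blocks until the queue has been filled.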