具有不同大小图像的Tensorflow输入数据集

时间:2018-08-23 10:30:27

标签: tensorflow input

我正在尝试使用具有不同大小的输入图像来训练全卷积神经网络。我可以通过遍历训练图像并在每次迭代时创建单个numpy输入来实现此目的,即

for image_input, label in zip(image_data, labels):
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
                                         x= {"x":image_input},
                                         y=label,
                                         batch_size=1, 
                                         num_epochs=None,
                                         shuffle=False)
    fcn_classifier.train(input_fn=input_func_gen, steps=1)

但是,通过这种方式,在每一步浪费大量资源之后,可以保存并加载模型。我还尝试过使用生成器即

一次创建整个数据集
def input_func_gen():
    dataset = tf.data.Dataset.from_generator(generator=generator, 
                                  output_types=(tf.float32, tf.int32))
    dataset = dataset.batch(1)
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()

def generator():
    filenames = ['building-d-mapimage-10-gt.png', 'building-dmapimage- 
                                                   16-gt.png']
    i = 0
    while i < len(filenames):        
        features, labels = loading.read_image_data(filenames[i])
        yield features, labels
        i += 1
        if i >= len(filenames):
            i = 0

然后

 fcn_classifier.train(input_fn=input_func_gen,
                      steps=100)   

但是,通过这种方式,训练变得非常慢,并且在第一次迭代后耗尽了内存,这表明数据集存在问题(在第一种情况下,如果使用单个输入,则训练运行必须更快)。生成器中特征的形状也是(1, image_height, image_width,3)。但是在模型中,我必须将它们重塑为4维张量,

input_shape = tf.shape(input)
input = tf.reshape(input, [1, input_shape[2], input_shape[3], 3])

而不是tf.reshape(input, [1, input_shape[1], input_shape[2], 3]),这表明输入的尺寸有些奇怪吗?在第一种情况下,我可以直接使用输入而无需重塑形状或其他任何内容?

1 个答案:

答案 0 :(得分:1)

我设法通过将input_func_gen更改为以下图片来解决图片尺寸变化的问题

def input_func_gen():
    load_path = '/path_to_images'
    data_set = 'dataset_to_use'
    image_data, labels = loading.load_image_data_grayscale(load_path,data_set)
    dataset = tf.data.Dataset.from_generator(lambda: 
                              itertools.zip_longest(image_data, labels),
                              output_types=(tf.float32, tf.int32),
                              output_shapes=(tf.TensorShape([1, None, None, 
                                             3]), tf.TensorShape([1, None])))
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()
相关问题