tf.decode_csv-有关参数的问题

时间:2019-06-25 15:11:25

标签: python tensorflow machine-learning computer-vision

我正在使用TensorFlow处理图像分类模型,并一直在检查所有代码以确保我理解它;除了输入功能的一部分,我对所有这些都了解。

在输入功能中,csv文件(火车/评估数据文件的名称)被转换为两个张量,每列一个。并将图像本身转换为二进制数据。

在父函数make_input_fn中,csv_row不是参数。嵌套在该父函数中的是_input_fn,而嵌套在其中的是DEcode_csv函数。

所以我不明白的是:csv_row不是make_input_fn中的参数,而是decode_csv函数的参数。代码如何知道-需要一种更好的放置方式-csv_row是什么?

我已经在其他地方看到过类似的代码,所以我知道它是正确的,但我只是想了解它的工作原理。

非常感谢任何帮助。


def make_input_fn(csv_of_filenames, batch_size, mode, augment = False):
    def _input_fn():
        def decode_csv(csv_row):
            filename, label = tf.decode_csv(records = csv_row, record_defaults = [[""],[""]])
            image_bytes = tf.read_file(filename = filename)
            return image_bytes, label

        # Create tf.data.dataset from filename
        dataset = tf.data.TextLineDataset(filenames = csv_of_filenames).map(map_func = decode_csv)     

        if augment: 
            dataset = dataset.map(map_func = read_and_preprocess_with_augment)
        else:
            dataset = dataset.map(map_func = read_and_preprocess)

        if mode == tf.estimator.ModeKeys.TRAIN:
            num_epochs = None # indefinitely
            dataset = dataset.shuffle(buffer_size = 10 * batch_size)
        else:
            num_epochs = 1 # end-of-input after this

        dataset = dataset.repeat(count = num_epochs).batch(batch_size = batch_size)
        return dataset.make_one_shot_iterator().get_next()
    return _input_fn

0 个答案:

没有答案