tf.estimator.train_and_evaluate的input_fn优化

时间:2018-02-25 18:26:12

标签: python performance tensorflow tensorflow-datasets tensorflow-estimator

我正在构建一个TensorFlow Estimator,我希望使用tf.estimator.train_and_evaluate()函数进行训练和评估。此函数的doc提供以下建议:

  

在进行评估之前,还建议对模型进行更长时间的训练,例如多个时期,因为输入管道从头开始进行每次训练。

这是有道理的,因为train_and_evaluate()通过在调用estimator.train()estimator.evaluate()之间交替工作,拆掉每个新调用的计算图。在我的情况下,这是一个问题,因为我想相对经常评估模型,而我的input_fn似乎在设置中有很多开销。它目前看起来像这样:

def input_fn():
    # Build dataset from generator
    dataset = tf.data.Dataset.from_generator(
        generator=instance_generator,
        output_types=types,
        output_shapes=shapes,
    )

    dataset = dataset.shuffle(buffer_size=dataset_size)
    dataset = dataset.repeat(epochs_per_eval)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)

    return dataset

我怀疑此功能的大部分时间都来自于混洗,因为它需要先生成整个数据集。改组可能并不慢,但我的instance_generator是。理想情况下,我想找到一种方法,避免每次列车/评估呼叫都必须从发电机重建数据集。有什么办法可以使用Dataset类来实现这个目的吗?有没有办法可以在生成数据集之后缓存数据集的状态,以便在第一次调用之后对input_fn的每次新调用都变得更便宜?

1 个答案:

答案 0 :(得分:0)

也许你可以使用除tf.data.Dataset.from_generator之外的tf.data.Dataset.range。以下是示例代码: 首先,定义Python类

import tensorflow as tf
import time

class instance_generator():
    def __init__(self):
        #doing some initialization
        self.data_index = {n:str(n) for n in range(1000)}# create index othre than pretreat data

    def _hard_work(self, n):
        time.sleep(1) #doing the pretreating work
        return self.data_index[n]

    def __call__(self):
        def get_by_index(i):
            return tf.py_func(lambda i: self._hard_work(i), inp=[i], Tout=types)

        dataset = tf.data.Dataset.range(len(self.data_index))
        dataset = dataset.shuffle(buffer_size=dataset_size)
        dataset = dataset.repeat(epochs_per_eval)
        dataset = dataset.map(get_by_index)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(1)
        return dataset.make_one_shot_iterator().next()

然后,将instance_generator类提供给tf.estimator:

data_train = instance_generator('train')
data_eval = instance_generator('eval')
model = tf.estimator.DNNClassifier(...)
tf.estimator.train_and_evaluate(
    estimator=model,
    train_spec=tf.estimator.TrainSpec(data_train),
    eval_spec=tf.estimator.Estimator(data_eval)
)

如果初始化步骤耗时,则只运行一次,每当估算器创建一个新图形时,它就会生成数据集。 如果数据预处理非常耗时,则它仅适用于数据的输送批次,而不适用于整个数据集。索引上的随机播放和重复非常便宜。 希望它有所帮助。

相关问题