Tensorflow:如何在Estimator中使用来自生成器的数据集

时间:2018-02-13 14:30:50

标签: tensorflow tensorflow-datasets

尝试构建简单模型,以弄清楚如何处理tf.data.Dataset.from_generator。我无法理解如何设置output_shapes参数。我尝试了几种组合,包括没有指定它,但由于张量的形状不匹配仍然会收到一些错误。这个想法只是产生两个带有SIZE = 10的numpy数组,并用它们运行线性回归。这是代码:

SIZE = 10


def _generator():
    feats = np.random.normal(0, 1, SIZE)
    labels = np.random.normal(0, 1, SIZE)
    yield feats, labels


def input_func_gen():
    shapes = (SIZE, SIZE)
    dataset = tf.data.Dataset.from_generator(generator=_generator,
                                             output_types=(tf.float32, tf.float32),
                                             output_shapes=shapes)
    dataset = dataset.batch(10)
    dataset = dataset.repeat(20)
    iterator = dataset.make_one_shot_iterator()
    features_tensors, labels = iterator.get_next()
    features = {'x': features_tensors}
    return features, labels


def train():
    x_col = tf.feature_column.numeric_column(key='x', )
    es = tf.estimator.LinearRegressor(feature_columns=[x_col])
    es = es.train(input_fn=input_func_gen)

另一个问题是,是否可以使用此功能为tf.feature_column.crossed_column的要素列提供数据?总体目标是在批处理培训中使用Dataset.from_generator功能,在数据不适合内存的情况下,数据从数据库加载到数据块。所有意见和例子都受到高度赞赏。

谢谢!

2 个答案:

答案 0 :(得分:10)

this github issue的可选output_shapes参数允许您指定从生成器中生成的值的形状。它的类型有两个约束,用于定义如何指定它:

  • output_shapes参数是一个"嵌套结构" (例如元组,元组元组,元组字典等)必须与生成器产生的值的结构相匹配。

    在您的计划中,_generator()包含声明yield feats, labels。因此"嵌套结构"是两个元素的元组(每个数组一个)。

  • output_shapes结构的每个组件应与相应张量的形状匹配。数组的形状始终是维度的元组。 (tf.Tensor的形状更为通用:请参阅tf.data.Dataset.from_generator()进行讨论。)让我们看一下feats的实际形状:

    >>> SIZE = 10
    >>> feats = np.random.normal(0, 1, SIZE)
    >>> print feats.shape
    (10,)
    

因此output_shapes参数应该是一个2元素的元组,其中每个元素都是(SIZE,)

shapes = ((SIZE,), (SIZE,))
dataset = tf.data.Dataset.from_generator(generator=_generator,
                                         output_types=(tf.float32, tf.float32),
                                         output_shapes=shapes)

最后,您需要向this Stack Overflow questiontf.feature_column.numeric_column() API提供有关形状的更多信息:

x_col = tf.feature_column.numeric_column(key='x', shape=(SIZE,))
es = tf.estimator.LinearRegressor(feature_columns=[x_col],
                                  label_dimension=10)

答案 1 :(得分:0)

@crafet是的,您不能用这种方式。您只能使用批处理大小来加快处理速度

相关问题