数据集,包含来自tf.dynamic_partition的可变大小的项目

时间:2018-02-06 21:02:08

标签: python tensorflow tensorflow-datasets

this question类似,我想从列表中构建一个TF dataset,每个元素都有不同的大小。但是,与链接的问题不同,我想从tf.dynamic_partition的输出生成数据集,该输出会输出张量列表。

我的设置:

import tensorflow as tf
D = tf.data.Dataset # shorthand notation

x = tf.range(9) # Array to be partitioned
p = tf.constant([1,0,2,0,0,0,2,2,1]) # Defines partitions

因此,数据集应包含三个元素,分别包含[1 3 4 5][0 8][2 6 7]

正如预期的那样,直接方法失败了:

dataset = D.from_tensor_slices(tf.dynamic_partition(x,p,3))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    nl = sess.run(next_element)
  

tensorflow.python.framework.errors_impl.InvalidArgumentError:形状   所有输入必须匹配:values [0] .shape = [4]!= values [1] .shape =   [2]

我尝试的下一件事是申请solution of the linked question,应用from_generator

dataset = D.from_generator(lambda: tf.dynamic_partition(x,p,3), tf.int32, output_shapes=[None])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    nl = sess.run(next_element)
  

tensorflow.python.framework.errors_impl.InvalidArgumentError:   exceptions.ValueError:使用序列设置数组元素。

如何从tf.dynamic_partition的输出创建包含可变大小项目的数据集?

1 个答案:

答案 0 :(得分:3)

cout << "BST with size = " << size << "[ "; 不起作用,因为它期望生成器函数产生numpy数组而不是张量。

解决问题的一种方法是为分区的每个元素创建一个数据集。在您的情况下,您将数据分为3组,因此您将创建3个数据集并将它们与tf.data.Dataset.concatenate()组合:

from_generator