Tensorflow,如何连接具有不同批处理大小的多个数据集

时间:2018-10-19 17:20:02

标签: python tensorflow tensorflow-datasets

想象我有:

  • 具有数据[5、5、5、5、5]的数据集1
  • 具有数据[4,4]的数据集2

我想从两个数据集中提取批次并将它们连接起来,以便获得3号批次,其中:

  • 我读取的数据集1的批次大小为2
  • 我读取了批处理大小为1的数据集2。

如果某些数据集先清空,我也想读取最后一批。 在这种情况下,我将得到[5,5,4],[5,5,4],[5]作为我的最终结果。

我该怎么做? 我在这里看到了答案:Tensorflow how to generate unbalanced combined data sets

这是一个很好的尝试,但是如果其中一个数据集先排空,则它不起作用(因为当您尝试从先清空的数据集中获取元素时,tf.errors.OutOfRangeError会先行输出)我没有得到最后一批)。因此我只会得到[5,5,4],[5,5,4]

我考虑过使用tf.contrib.data.choose_from_datasets

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)
choice_dataset = [1, 2, 1, 2, 1]
ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)

这类作品,但相当不雅致(有批量和批处理);而且,我对批次中的确切内容并没有真正的控制权。 (例如,如果ds1为[7] * 7,批次大小为2,而ds2为[2,2],批次大小为1,则我将得到[7,7,1],[7,7,1],[7 ,7、7],但是如果我实际上想拥有[7、7、1],[7、7、1],[7、7],[7],该怎么办?即保持每个数据集中的元素数量固定

还有其他更好的解决方案吗?

我的另一个想法是尝试使用tf.data.Dataset.flat_map

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])
batch_sizes = [2, 1]
def concat(*inputs):
  concat = partial(functools.reduce, lambda x, y: x.concatenate(y))
  datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
  datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
  return concat(datasets)
dataset = (tf.data.Dataset
           .zip((ds1, ds2))
           .flat_map(_concat_and_batch)
           .batch(sum(batch_sizes)))

但它似乎不起作用。

3 个答案:

答案 0 :(得分:1)

这是一个解决方案。它有一些问题,但我希望它能满足您的需求。

想法如下:您将两个数据集分别进行批处理,将它们压缩在一起,然后执行map函数将每个压缩的元组合并为一个批处理(到目前为止,这与{{3}中的建议类似}和this的答案。)

您注意到的问题是,压缩仅适用于两个长度相同的数据集。否则,一个数据集会先消耗掉另一个,并且不使用其余未消耗的元素。

我的解决方法是将两个数据集连接到另一个无限虚拟数据集。该虚拟数据集仅包含您不会在实际数据集中显示的值。这样可以消除拉链问题。但是,您需要摆脱所有虚拟元素。通过过滤和映射可以很容易地做到这一点。

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

# we assume that this value will never occur in `ds1` and `ds2`:
UNUSED_VALUE = -1 

# an infinite dummy dataset:
dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat() 

# make `ds1` and `ds2` infinite:
ds1 = ds1.concatenate(dummy_ds)
ds2 = ds2.concatenate(dummy_ds)

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

# this is the solution mentioned in the links above
ds = tf.data.Dataset.zip((ds1,ds2))
ds = ds.map(lambda x1, x2: tf.concat((x1,x2),0))

# filter the infinite dummy tail:
ds = ds.filter(lambda x: tf.reduce_any(tf.not_equal(x,UNUSED_VALUE)))

# filter from batches the dummy elements:
ds = ds.map(lambda x: tf.boolean_mask(x,tf.not_equal(x,UNUSED_VALUE)))

此解决方案有两个主要问题:

(1)我们需要为UNUSED_VALUE提供一个值,我们确信该值不会出现在数据集中。我怀疑有一个解决方法,可能是通过使虚拟数据集由空张量(而不是具有恒定值的张量)组成,但是我还不知道该怎么做。

(2)尽管此数据集具有有限数量的元素,但以下循环将永远不会终止:

iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess = tf.Session()
while True:
    print(sess.run(batch))

原因是迭代器不断过滤出虚拟示例,而不知道何时停止。可以通过将上面的repeat()调用更改为repeat(n)来解决,其中n是一个您知道的数字,它长于两个数据集的长度之差。

答案 1 :(得分:1)

如果您不介意在构建新数据集期间运行会话,则可以执行以下操作:

import tensorflow as tf
import numpy as np

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next()
batch2 = iter2.get_next()

sess = tf.Session()

# define a generator that will sess.run both datasets, and will return the concatenation of both
def GetBatch():
    while True:
        try:
            b1 = sess.run(batch1)
        except tf.errors.OutOfRangeError:
            b1 = None
        try:
            b2 = sess.run(batch2)
        except tf.errors.OutOfRangeError:
            b2 = None
        if (b1 is None) and (b2 is None):
            break
        elif b1 is None:
            yield b2
        elif b2 is None:
            yield b1
        else:
            yield np.concatenate((b1,b2))

# create a dataset from the above generator
ds = tf.data.Dataset.from_generator(GetBatch,tf.int32)

请注意,如果需要,可以将上述会话隐藏\封装(例如,在函数内部),例如:

iter = ds.make_one_shot_iterator()
batch = iter.get_next()

sess2 = tf.Session()

while True:
    print(sess2.run(batch))

答案 2 :(得分:1)

这里是一个解决方案,要求您使用“控制输入”,选择要使用的批次,然后根据首先使用的数据集来决定。可以使用抛出的异常来检测到这一点。

为解释该解决方案,我将首先提出一种无效的尝试。

尝试的解决方案#1

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next(name='batch1')
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)

# this is a "control" placeholder. Its value determines whether to use `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
               tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                       lambda:batch12,
        lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
                       lambda:batch1,
        lambda:batch2)) # else, use `batch2`

sess = tf.Session()

which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which==0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break

该解决方案不起作用,因为对于which_batch的任何值,tf.cond()命令都会评估其分支的所有前身(请参见this answer)。因此,即使which_batch的值为1,也会计算batch2并抛出OutOfRangeError

尝试的解决方案2

可以通过将batch1batch2batch12的定义移到函数中来解决此问题。

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

def get_batch1():
    batch1 = iter1.get_next(name='batch1')
    return batch1

def get_batch2():
    batch2 = iter2.get_next(name='batch2')
    return batch2

def get_batch12():
    batch1 = iter1.get_next(name='batch1_')
    batch2 = iter2.get_next(name='batch2_')
    batch12 = tf.concat((batch1, batch2), 0)
    return batch12

# this is a "control" placeholder. It's value determines whether to ues `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
               tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                       get_batch12,
        lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
                       get_batch1,
        get_batch2)) # elif `which_batch`==2, use `batch2`

sess = tf.Session()

which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which==0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break

但是,这也不起作用。原因是在形成batch12并使用完数据集ds2的步骤中,我们从数据集ds1中提取了该批次,并在不使用它的情况下将其“丢弃”。

解决方案

我们需要一种机制来确保在消耗其他数据集的情况下不“丢弃”任何批次。为此,我们可以定义一个变量,该变量将被分配给当前批次的ds1,但只能在尝试获取 batch12之前分配。否则,此变量将保留其先前的值。然后,如果batch12由于消耗了ds1而失败,则此分配将失败并且batch2没有被丢弃,我们下次可以使用它。否则,如果batch12由于消耗了ds2而失败,那么我们在定义的变量中有batch1的备份,使用此备份后,我们可以继续进行{{ 1}}。

batch1