将Tensorflow数据集API创建的数据集拆分为Train and Test?

时间:2018-01-11 18:34:30

标签: tensorflow dataset

有谁知道如何将Tensorflow中数据集API(tf.data.Dataset)创建的数据集拆分为Test and Train?

9 个答案:

答案 0 :(得分:18)

假设您有all_dataset变量tf.data.Dataset类型:

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)

测试数据集现在有前1000个元素,其余的用于训练。

答案 1 :(得分:7)

此处大多数答案使用take()skip(),这需要事先了解数据集的大小。这并非总是可能,或者很难/难以确定。

实际上,您可以做的是对数据集进行切片,以使每N条记录中有1条成为验证记录。

要做到这一点,让我们从0-9的简单数据集开始:

dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

现在在我们的示例中,我们将对其进行切片,以使我们得到3/1训练/验证拆分。意思是3条记录将进行训练,然后1条记录进行验证,然后重复。

split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]

因此,第一个dataset.window(split, split + 1)说要获取split个元素的数量(3),然后前进split + 1个元素,然后重复。 + 1有效地跳过了我们将在验证数据集中使用的1元素。
flat_map(lambda ds: ds)是因为window()批量返回结果,这是我们不想要的。因此,我们将其压平。

然后我们首先获取验证数据,skip(split)会跳过在第一个训练窗口中获取的元素的前split个数字(3),因此我们在第四个元素上开始迭代。 window(1, split + 1)然后抓取1个元素,前进split + 1 (4),然后重复。

关于嵌套数据集的说明:
上面的示例适用于简单的数据集,但是如果嵌套数据集,flat_map()将产生错误。为了解决这个问题,您可以将flat_map()换成可以处理简单和嵌套数据集的更复杂的版本:

.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))

答案 2 :(得分:4)

您可以使用Dataset.take()Dataset.skip()

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

为了更笼统,我举了一个使用70/15/15火车/ val /测试划分的示例,但是如果您不需要测试或val集,则只需忽略最后两行即可。

Take

  

从该数据集中创建一个最多包含count个元素的数据集。

Skip

  

创建一个数据集,该数据集从该数据集中跳过计数元素。

您可能还想研究Dataset.shard()

  

创建一个仅包含此数据集1 / num_shards的数据集。


免责声明我在回答this one之后偶然发现了这个问题,所以我以为我会传播爱心

答案 3 :(得分:3)

@ted的答案将引起某些重叠。试试这个。

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)

使用下面的代码进行测试。

tf.enable_eager_execution()

dataset = tf.data.Dataset.range(100)

train_size = 20
valid_size = 30
test_size = 50

train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)

for i in train:
    print(i)

for i in valid:
    print(i)

for i in test:
    print(i)

答案 4 :(得分:0)

现在Tensorflow不包含任何工具。
您可以使用sklearn.model_selection.train_test_split生成训练/评估/测试数据集,然后分别创建tf.data.Dataset

答案 5 :(得分:0)

您可以使用shard

dataset = dataset.shuffle()  # optional
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)

请参阅: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard

答案 6 :(得分:0)

在已知数据集大小的情况下:

from typing import Tuple
import tensorflow as tf

def split_dataset(dataset: tf.data.Dataset, 
                  dataset_size: int, 
                  train_ratio: float, 
                  validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    assert (train_ratio + validation_ratio) < 1

    train_count = int(dataset_size * train_ratio)
    validation_count = int(dataset_size * validation_ratio)
    test_count = dataset_size - (train_count + validation_count)

    dataset = dataset.shuffle(dataset_size)

    train_dataset = dataset.take(train_count)
    validation_dataset = dataset.skip(train_count).take(validation_count)
    test_dataset = dataset.skip(validation_count + train_count).take(test_count)

    return train_dataset, validation_dataset, test_dataset

示例:

size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2

ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)

答案 7 :(得分:-1)

@ apatsekin,@ ted最近我的声誉不超过50,所以我只需要在这里回答答案,我想直接使用.take方法获取或不获取测试数据集是否合理。如果数据集在每个纪元都经过了改组,那么它将得到不同的TRAIN / TEST划分,因为在训练过程中,我们需要测试集永远不会出现在训练集中。所以这应该是一个问题

或者我们在shuffle中添加一个参数:

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle( reshuffle_each_iteration = False )
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

答案 8 :(得分:-2)

无法发表评论,但以上答案有重叠且不正确。将BUFFER_SIZE设置为DATASET_SIZE以获得完美的随机播放。尝试使用其他大小的val / test大小进行验证。答案应该是:

DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy()
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = full_dataset.shuffle(BUFFER_SIZE)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.take(val_size)
test_dataset = test_dataset.skip(val_size)