使用tf.data读取多个通用文件

时间:2019-03-26 23:35:16

标签: python python-3.x tensorflow tensorflow-datasets tensorflow2.0

我正在尝试使用tf.data实现输入管道。 功能位于从matlab导出的矩阵中,而标签位于其他需要特殊功能才能读取的文件中。

可以给定一个数字来计算必须加载的文件名。

这就是我的实现方式

def load_files(k):
    mesh_file = file_path(k, "off", flags.dataset_mesh)
    mat_file = file_path(k, "mat", flags.dataset_mat)

    mesh = pymesh.load_mesh(mesh_file)
    mat = scipy.io.loadmat(mat_file)

    return mesh.vertices, mat


def generator_fn():
    return (load_files(x) for x in range(1000000 + 1))


def input_fn() -> Dataset:
    dataset = tf.data.Dataset.from_generator(generator_fn,
           output_types=(tf.as_dtype(tf.float32), tf.as_dtype(tf.float32)), )
    dataset = dataset.batch(batch_size=flags.batch_size).repeat()
    dataset = dataset.cache()
    dataset = dataset.prefetch(buffer_size=flags.prefetch_buffer_size)
    return dataset

问题在于GPU使用率非常低,大约为5%(2080 ti)。我不确定瓶颈在哪里。 我正在使用简单的MLP进行测试,但是尽管添加了层或每层神经元,但gpu的使用似乎并没有改变。

我以这种方式进行训练:

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(n_input,)),
    keras.layers.Dense(1024, activation=tf.nn.relu),
    .
    .
    .
    keras.layers.Dense(1024, activation=tf.nn.relu),
    keras.layers.Dense(n_output, activation=None)
])

model.compile(optimizer='adam', loss='mean_squared_error')

model.fit(input_fn().make_one_shot_iterator(), steps_per_epoch=1000000, epochs=1)

所以,我认为问题可能出在以下方面:关于如何馈送数据(由于我在SSD NVMe上,问题不应该仅仅是文件读取),关于如何进行培训,或者尽管我添加了各层,但这只是一个简单的网络。

但是,我想知道是否有一种更有效的方式来馈送数据。

我正在使用tensorflow-gpu 2.0.0a0,我从lambda-labs运行了一个基准,它能够100%使用gpu

0 个答案:

没有答案
相关问题