是否可以将NumPy函数映射到tf.data.dataset?

时间:2019-04-23 08:09:23

标签: python numpy tensorflow

我有以下简单代码:

arrCols = Array("APPLE", "ORANGE")
With Sheet1
    For i = LBound(arrCols) To UBound(arrCols)
        Set colheader = .Rows(1).Find(arrCols(i), , xlValues, xlWhole, xlByColumns, xlNext, False)
        Debug.Print arrCols(i) & "=" & colheader.Address
        Set colDiff = .Rows("1:2").Find("test", .Cells(1, colheader.Column), xlValues, xlWhole, xlByColumns, xlNext, False)
        Debug.Print "test =" & colDiff.Address
    Next
End With

基本上,代码创建了一个import tensorflow as tf import numpy as np filename = # a list of wav filenames x = tf.placeholder(tf.string) def mfcc(x): feature = # some function written in NumPy to convert a wav file to MFCC features return feature mfcc_fn = lambda x: mfcc(x) # create a training dataset train_dataset = tf.data.Dataset.from_tensor_slices((x)) train_dataset = train_dataset.repeat() train_dataset = train_dataset.map(mfcc_fn) train_dataset = train_dataset.batch(100) train_dataset = train_dataset.prefetch(buffer_size=1) # create an iterator and iterate over training dataset iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes) train_iterator = iterator.make_initializer(train_dataset) with tf.Session() as sess: sess.run(train_iterator, feed_dict={x: filename}) 对象,该对象加载了wav文件并将其转换为mfcc功能。在这里,数据转换发生在tf.data.dataset处,在这里我将用NumPy编写的mfcc函数应用于所有输入数据。

显然,该代码在这里不起作用,因为NumPy不支持对train_dataset.map(mfcc_fn)对象的操作。如果必须在NumPy中编写函数,是否可以将函数映射为输入到tf.placeholder?我不使用TensorFlow的内置MFCC功能转换的原因是因为TensorFlow中的FFT函数提供的输出与其NumPy对应的输出显着不同(如illustraded here),并且我正在构建的模型倾向于MFCC功能使用NumPy生成。

1 个答案:

答案 0 :(得分:1)

您可以使用tf.py_func函数或tf.py_function(较新的版本)来实现。它完全可以满足您的要求,它将在数组中包装在数组上运行的numpy函数,您可以将其包含在数据集图中。