Tensorflow padded_batch用于稀疏张量?

时间:2018-06-11 16:09:19

标签: tensorflow tensorflow-datasets

我有一段代码,如下所示:

# Toy tf.data pipeline (TF 1.x): pairs variable-length integer label
# sequences with random (3, 5) float32 "spectrogram" arrays and converts
# the dense labels to a SparseTensor via tf.contrib.layers.dense_to_sparse.
import tensorflow as tf
import numpy as np
# Four label sequences of unequal length (3, 4, 5 and 2 elements).
sequences = np.array([[1,3,4],[5,6,7,8],[9,10,11,12,13],[14,15]])
def generator():
  # Yield (labels, spectrogram) pairs; every spectrogram has the same
  # fixed (3, 5) shape here, which is why plain batch() works below.
  for el in sequences:
    yield el, np.random.randn(3,5).astype('float32')

def parser(dense_tensor,spectrogram):
  # Convert the dense label vector into a SparseTensor inside the pipeline.
  labels = tf.contrib.layers.dense_to_sparse(dense_tensor)
  return spectrogram,labels


dataset = tf.data.Dataset().from_generator(generator, output_types= (tf.int64, tf.float32), output_shapes=([None],[None,None]))

# batch(2) only works because all spectrograms share one shape; variable
# lengths would need padded_batch, which cannot pad the SparseTensor that
# parser() produces -- that mismatch is the crux of the question.
dataset = dataset.map(lambda den, spec:  parser(den,spec)).batch(2)
iter = dataset.make_initializable_iterator()  # NOTE(review): shadows builtin iter()
spectrogram,labels = iter.get_next()

with tf.Session() as sess:
  sess.run(iter.initializer)
  while True:
    try:
      spar,spe = sess.run([labels,spectrogram])
      print(spar, spe.shape)
    except Exception as e:
      # Broad catch: exits on the end-of-dataset OutOfRangeError but also
      # silently hides any other runtime error.
      #print(e)
      break

我在使用 tf.data 为语音转文本任务获取标签和频谱图。上面是一个玩具示例:如果所有语音信号长度相同就没有问题,但对于由不同长度信号组成的批次,我需要使用 padded_batch;然而 dense_to_sparse 产生的稀疏张量无法进行填充批处理。有什么办法可以让 padded_batch 与稀疏张量一起使用吗?

1 个答案:

答案 0(得分:0)

# Answer: keep the labels DENSE through the pipeline, pad them in
# padded_batch with a sentinel value (100), and only convert the padded
# batch to a SparseTensor afterwards -- dense_to_sparse(eos_token=100)
# strips the padding entries while building the sparse representation.
import tensorflow as tf
import numpy as np

# Variable-length label sequences (defined here so the snippet runs standalone).
sequences = np.array([[1, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12, 13], [14, 15]])

def generator():
  # Spectrograms now have a random number of rows (1-3), so batching
  # genuinely requires padded_batch.
  for el in sequences:
    yield el, np.random.randn(np.random.randint(1, 4), 5).astype('float32')

def parser(dense_tensor, spectrogram):
  # No dense_to_sparse here: the conversion happens AFTER padded_batch.
  labels = dense_tensor
  return spectrogram, labels

dataset = tf.data.Dataset().from_generator(
    generator,
    output_types=(tf.int64, tf.float32),
    output_shapes=([None], [None, None]))
# Pad spectrograms with 0.0 and labels with the int64 sentinel 100.
dataset = dataset.map(parser).padded_batch(
    2, ([None, None], [None]),
    padding_values=(0., tf.constant(100, dtype=tf.int64)))
iterator = dataset.make_initializable_iterator()
spectrogram, labels = iterator.get_next()
# Convert the padded dense batch to a SparseTensor, dropping the sentinel.
res = tf.contrib.layers.dense_to_sparse(labels, eos_token=100)
print(res)
with tf.Session() as sess:
  sess.run(iterator.initializer)
  while True:
    try:
      spar, spe, res1 = sess.run([labels, spectrogram, res])
      print(res1, spar, spe)
    except tf.errors.OutOfRangeError:
      # End of the dataset: stop iterating. (Narrowed from a bare
      # `except Exception`, which also hid real errors.)
      break