新的dataset.map转换和查找表:不兼容的字符串类型

时间:2017-06-13 10:35:15

标签: python tensorflow

我期待使用1.2中提供的new Dataset API,但我在应用简单的map转换时会遇到问题,该转换会在index table中查找单词。

考虑这个简单的例子:

import tensorflow as tf

mapping_strings = tf.constant(["emerson", "lake", "palmer"])
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=mapping_strings, num_oov_buckets=1)

dataset = tf.contrib.data.Dataset.from_tensor_slices(
    tf.constant(["emerson", "lake"]))

# Here is the map operation that generates an error.
dataset = dataset.map(lambda x: table.lookup(x))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    sess.run(next_element)

使用1.2.0-rc2,会产生以下错误:

TypeError: In op 'string_to_index_Lookup/hash_table_Lookup', input types ([tf.string, tf.string, tf.int64]) are not compatible with expected types ([tf.string_ref, tf.string, tf.int64])

查找表需要tf.string_ref,并且似乎无法满足此要求。

由于我是TensorFlow的新手,我不怀疑这是一个错误但是用法不好。我的错是什么?

谢谢!

编辑2017-06-15:但是,如果版本为nightly,则会引发其他错误:

ValueError: Cannot capture a stateful node (name:string_to_index/hash_table, type:HashTableV2) by value.

1 个答案:

答案 0 :(得分:3)

您可能希望使用Dataset.make_initializable_iterator()而不是Dataset.make_one_shot_iterator(),因为哈希表是有状态的。

以下代码为我工作:

import tensorflow as tf

mapping_strings = tf.constant(["emerson", "lake", "palmer"])
table = tf.contrib.lookup.index_table_from_tensor(
  mapping=mapping_strings, num_oov_buckets=1)

dataset = tf.contrib.data.Dataset.from_tensor_slices(
  tf.constant(["emerson", "lake"]))

# Here is the map operation that generates an error.
dataset = dataset.map(lambda x: table.lookup(x))

iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer

with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  sess.run(init_op)
相关问题